Generative Modeling with Optimal Transport Maps

The repository contains reproducible PyTorch source code of our paper Generative Modeling with Optimal Transport Maps, ICLR 2022. It focuses on Optimal Transport Modeling (OTM) in ambient space, e.g. spaces of high-dimensional images. While analogous approaches consider OT maps in the latent space of an autoencoder, this paper fits an OT map directly between the noise space and the ambient space. The method is evaluated on generative modeling and unpaired image restoration tasks. In particular, large-scale computer vision problems such as denoising, colorization, and inpainting are considered for unpaired image restoration. The overall pipeline, with OT as the generative map and OT as the cost of the generative model, is given below.
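For orientation, the max-min problem behind OT-based generative maps of this kind can be written as follows. The notation is ours and this is a sketch of the general formulation, not a transcription of the paper's equations: P is the noise distribution, Q the data distribution, T the transport map (generator), f a potential, and c the transport cost.

```latex
\sup_{f}\,\inf_{T}\;\int \Big[c\big(x, T(x)\big) - f\big(T(x)\big)\Big]\,d\mathbb{P}(x) \;+\; \int f(y)\,d\mathbb{Q}(y)
```

In practice both T and f would be neural networks, the two integrals would be replaced by minibatch averages, and T is updated to minimize the objective while f is updated to maximize it; the learned T then serves as the generative map.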

Latent Space Optimal Transport

Our method differs from the prevalent approach of computing OT in the latent space, shown below.

Ambient Space Mass Transport

The scheme of our approach for learning OT maps between spaces of unequal dimensions.
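The minibatch objective for the unequal-dimension case can be sketched in plain Python. This is a hypothetical illustration, not the repository's code: `embed` is a stand-in for a fixed map that lifts low-dimensional noise into the ambient space (here it simply repeats coordinates, an assumption made only so the dimensions match), the cost is half the squared Euclidean distance, and `otm_objective` returns the estimate that training would minimize over `T` and maximize over `f`.

```python
from typing import Callable, Iterable, List, Sequence

Vector = List[float]


def embed(z: Sequence[float], out_dim: int) -> Vector:
    """Hypothetical fixed map from noise (dim H) to ambient space (dim D).

    Repeats each noise coordinate; assumes out_dim is a multiple of len(z).
    """
    if out_dim % len(z) != 0:
        raise ValueError("out_dim must be a multiple of the noise dimension")
    k = out_dim // len(z)
    return [v for v in z for _ in range(k)]


def transport_cost(z: Sequence[float], y: Sequence[float]) -> float:
    """Half the squared Euclidean distance between embedded noise z and ambient point y."""
    e = embed(z, len(y))
    return 0.5 * sum((a - b) ** 2 for a, b in zip(e, y))


def otm_objective(
    noise: Iterable[Sequence[float]],
    data: Iterable[Sequence[float]],
    T: Callable[[Sequence[float]], Vector],
    f: Callable[[Sequence[float]], float],
) -> float:
    """Minibatch estimate of E_z[c(z, T(z)) - f(T(z))] + E_y[f(y)]."""
    gen_terms = []
    for z in noise:
        y = T(z)  # generated ambient-space sample
        gen_terms.append(transport_cost(z, y) - f(y))
    data_terms = [f(y) for y in data]
    return sum(gen_terms) / len(gen_terms) + sum(data_terms) / len(data_terms)
```

For instance, with `T = lambda z: embed(z, 4)` and a constant potential `f`, the transport cost is zero and the two potential terms cancel, so the estimate is 0.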

Prerequisites

The implementation is GPU-based. A single GPU (V100) is enough to run each experiment. Tested with torch==1.4.0 and torchvision==0.5.0. To reproduce the reported results, use these exact versions of PyTorch and its dependencies, as other versions might be incompatible.
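Before launching an experiment, it can help to confirm that the installed packages match the tested versions. The helper below is a minimal sketch that is not part of the repository; it only encodes the two versions stated above, uses `importlib.metadata` (Python 3.8+), and ignores local build tags such as `+cu101` when comparing.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional

# Versions this README reports as tested.
TESTED_VERSIONS = {"torch": "1.4.0", "torchvision": "0.5.0"}


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Return the installed version of each package, or None if it is missing."""
    found: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


def version_mismatches(
    installed: Dict[str, Optional[str]],
    expected: Dict[str, str] = TESTED_VERSIONS,
) -> Dict[str, str]:
    """Map every package that differs from its tested version to a short message."""
    problems: Dict[str, str] = {}
    for name, want in expected.items():
        have = installed.get(name)
        if have is None:
            problems[name] = f"not installed (tested with {want})"
        elif have.split("+")[0] != want:  # drop local tags like "+cu101"
            problems[name] = f"found {have}, tested with {want}"
    return problems


if __name__ == "__main__":
    for pkg, msg in version_mismatches(installed_versions(TESTED_VERSIONS)).items():
        print(f"warning: {pkg} {msg}")
```

Splitting the check into `installed_versions` and `version_mismatches` keeps the comparison logic testable without PyTorch being installed.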

Repository structure

All experiments are provided as self-explanatory Python scripts.

Main Experiments

Execute the following commands in the source folder.

Training

  • python otm_mnist_32x32.py --train 1 -- OTM between noise and MNIST, 32x32, Grayscale;
  • python otm_cifar_32x32.py --train 1 -- OTM between noise and CIFAR10, 32x32, RGB;
  • python otm_celeba_64x64.py --train 1 -- OTM between noise and CelebA, 64x64, RGB;
  • python otm_celeba_denoise_64x64.py --train 1 -- OTM for unpaired denoising on CelebA, 64x64, RGB;
  • python otm_celeba_colorization_64x64.py --train 1 -- OTM for unpaired colorization on CelebA, 64x64, RGB;
  • python otm_celeba_inpaint_64x64.py --train 1 -- OTM for unpaired inpainting on CelebA, 64x64, RGB.

Run inference with the best training iteration, specified via --init_iter.

Inference

  • python otm_mnist_32x32.py --inference 1 --init_iter 100000
  • python otm_cifar_32x32.py --inference 1 --init_iter 100000
  • python otm_celeba_64x64.py --inference 1 --init_iter 100000
  • python otm_celeba_denoise_64x64.py --inference 1 --init_iter 100000
  • python otm_celeba_colorization_64x64.py --inference 1 --init_iter 100000
  • python otm_celeba_inpaint_64x64.py --inference 1 --init_iter 100000
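To run every experiment in sequence, the commands listed above can be wrapped in a small launcher. This is a hypothetical convenience script, not part of the repository: it uses only the flags shown in this README, assumes it is executed from the source folder, defaults to a dry run that just lists the commands, and uses 100000 for --init_iter as in the commands above, although the best iteration may differ per experiment.

```python
import shlex
import subprocess
from typing import List

# Experiment scripts listed in this README.
SCRIPTS = [
    "otm_mnist_32x32.py",
    "otm_cifar_32x32.py",
    "otm_celeba_64x64.py",
    "otm_celeba_denoise_64x64.py",
    "otm_celeba_colorization_64x64.py",
    "otm_celeba_inpaint_64x64.py",
]


def build_command(script: str, mode: str, init_iter: int = 100000) -> List[str]:
    """Build the argv for one run using only the flags documented above."""
    if mode == "train":
        return ["python", script, "--train", "1"]
    if mode == "inference":
        return ["python", script, "--inference", "1", "--init_iter", str(init_iter)]
    raise ValueError(f"unknown mode: {mode!r}")


def run_all(mode: str, dry_run: bool = True) -> List[str]:
    """Run every experiment in order; with dry_run=True, only return the command lines."""
    lines = []
    for script in SCRIPTS:
        cmd = build_command(script, mode)
        lines.append(shlex.join(cmd))  # shell-quoted form, for logging
        if not dry_run:
            subprocess.run(cmd, check=True)  # stop at the first failing run
    return lines


if __name__ == "__main__":
    print("\n".join(run_all("inference")))
```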

Toy Experiments in 2D

  • source/toy/OTM-GO MoG.ipynb -- Mixture of 8 Gaussians;
  • source/toy/OTM-GO Moons.ipynb -- Two Moons;
  • source/toy/OTM-GO Concentric Circles.ipynb -- Concentric Circles;
  • source/toy/OTM-GO S Curve.ipynb -- S Curve;
  • source/toy/OTM-GO Swirl.ipynb -- Swirl.

Refer to the Credits section for baselines.
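As an example of the 2D toy data, a mixture of 8 Gaussians can be sampled with the standard library alone. This is a minimal sketch under assumed settings: the means are placed evenly on a circle, and the `radius` and `std` defaults are illustrative choices, not the values used in the OTM-GO MoG notebook.

```python
import math
import random
from typing import List, Tuple


def sample_8_gaussians(
    n: int, radius: float = 2.0, std: float = 0.02, seed: int = 0
) -> List[Tuple[float, float]]:
    """Draw n points from 8 isotropic Gaussians whose means lie evenly on a circle.

    radius and std are illustrative defaults (assumptions), not the notebook's settings.
    """
    rng = random.Random(seed)  # local RNG so results are reproducible per seed
    centers = [
        (radius * math.cos(2 * math.pi * k / 8), radius * math.sin(2 * math.pi * k / 8))
        for k in range(8)
    ]
    points = []
    for _ in range(n):
        cx, cy = rng.choice(centers)  # pick a mixture component uniformly
        points.append((rng.gauss(cx, std), rng.gauss(cy, std)))
    return points
```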

Results

Optimal transport modeling between high-dimensional noise and ambient space.

Randomly generated samples

Optimal transport modeling for unpaired image restoration tasks.

The experimental setup considered for unpaired image restoration is shown below.

OTM for image denoising on test C part of CelebA, 64 × 64.

OTM for image colorization on test C part of CelebA, 64 × 64.

OTM for image inpainting on test C part of CelebA, 64 × 64.
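The three restoration tasks differ only in how the input images are degraded. The functions below are a hypothetical sketch of such degradations on images stored as nested lists of RGB pixels in [0, 1]; the noise level, the BT.601 grayscale weights, and the right-half mask are assumptions for illustration, since this README does not state the exact settings. In an unpaired setup, the degraded inputs and the clean targets would come from different images.

```python
import random
from typing import List

Pixel = List[float]        # [r, g, b], each in [0, 1]
Image = List[List[Pixel]]  # rows of pixels


def add_gaussian_noise(img: Image, sigma: float = 0.1, seed: int = 0) -> Image:
    """Denoising input: add i.i.d. Gaussian noise (assumed sigma) and clip to [0, 1]."""
    rng = random.Random(seed)
    return [
        [[min(1.0, max(0.0, c + rng.gauss(0.0, sigma))) for c in px] for px in row]
        for row in img
    ]


def to_grayscale(img: Image) -> Image:
    """Colorization input: BT.601 luma replicated over the three channels."""
    out = []
    for row in img:
        new_row = []
        for r, g, b in row:
            y = 0.299 * r + 0.587 * g + 0.114 * b
            new_row.append([y, y, y])
        out.append(new_row)
    return out


def mask_right_half(img: Image, fill: float = 0.0) -> Image:
    """Inpainting input: replace the right half of each row with a constant (assumed mask)."""
    width = len(img[0])
    return [
        [list(px) if j < width // 2 else [fill] * 3 for j, px in enumerate(row)]
        for row in img
    ]
```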

Optimal transport modeling for toy examples.

OTM in low-dimensional space, 2D.

Credits

Owner
Litu Rout
I am broadly interested in Optimization, Statistical Learning Theory, Interactive Machine Learning, and Optimal Transport.