A set of tests for evaluating large-scale algorithms that compute Wasserstein-2 transport maps.

Overview

Continuous Wasserstein-2 Benchmark

This is the official Python implementation of the NeurIPS 2021 paper Do Neural Optimal Transport Solvers Work? A Continuous Wasserstein-2 Benchmark (arXiv:2106.01954) by Alexander Korotin, Lingxiao Li, Aude Genevay, Justin Solomon, Alexander Filippov and Evgeny Burnaev.

The repository contains a set of continuous benchmark measures for testing optimal transport solvers for the quadratic cost (Wasserstein-2 distance), together with the code for the solvers and for their evaluation.
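For the quadratic cost, Brenier's theorem says that the gradient of a convex function pushing P to Q is the optimal map, which is what makes a benchmark pair with a known answer possible. As a toy illustration only (the pairs in this repository are trained with the [W2] solver, not built this way), the sketch below takes P = N(0, I) and the convex quadratic potential psi(x) = 0.5 x^T A x + b^T x, so the true map is T(x) = A x + b and W2^2(P, Q) = E||x - T(x)||^2 has a closed form. All function names are hypothetical and not part of src.

```python
import numpy as np


def make_quadratic_pair(dim, seed=0):
    """Toy pair (hypothetical helper): P = N(0, I) and T(x) = A x + b.

    A is symmetric positive definite, so T is the gradient of the convex
    potential psi(x) = 0.5 x^T A x + b^T x and, by Brenier's theorem, the
    W2-optimal map from P to Q = T#P.
    """
    rng = np.random.default_rng(seed)
    M = rng.normal(size=(dim, dim))
    A = M @ M.T / dim + 0.5 * np.eye(dim)  # symmetric positive definite
    b = rng.normal(size=dim)
    return A, b


def true_map(x, A, b):
    # Apply T row-wise to a batch x of shape (n, dim); A is symmetric, so x @ A == (A @ x.T).T
    return x @ A + b


def w2_squared_closed_form(A, b):
    # For x ~ N(0, I): E||x - A x - b||^2 = tr((I - A)^2) + ||b||^2
    D = np.eye(A.shape[0]) - A
    return np.trace(D @ D) + b @ b


def w2_squared_monte_carlo(A, b, n=200_000, seed=1):
    # Empirical transport cost E_P||x - T(x)||^2 of the true map
    rng = np.random.default_rng(seed)
    x = rng.normal(size=(n, A.shape[0]))
    return np.mean(np.sum((x - true_map(x, A, b)) ** 2, axis=1))


if __name__ == "__main__":
    A, b = make_quadratic_pair(4)
    print(w2_squared_closed_form(A, b), w2_squared_monte_carlo(A, b))
```

The Monte Carlo estimate should agree with the closed form up to sampling error; with a known map, any solver's output can be checked against T directly rather than only through the transport cost.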

Citation

@article{korotin2021neural,
  title={Do Neural Optimal Transport Solvers Work? A Continuous Wasserstein-2 Benchmark},
  author={Korotin, Alexander and Li, Lingxiao and Genevay, Aude and Solomon, Justin and Filippov, Alexander and Burnaev, Evgeny},
  journal={arXiv preprint arXiv:2106.01954},
  year={2021}
}

Pre-requisites

The implementation is GPU-based. A single GPU (~GTX 1080 Ti) is enough to run each experiment. Tested with

torch==1.3.0 torchvision==0.4.1

The code might not run as intended in newer torch versions.

Related repositories

Loading Benchmark Pairs

from src import map_benchmark as mbm

# Load benchmark pair for dimension 16 (2, 4, ..., 256)
benchmark = mbm.Mix3ToMix10Benchmark(16)
# OR load 'Early' images benchmark pair ('Early', 'Mid', 'Late')
# benchmark = mbm.CelebA64Benchmark('Early')

# Sample 32 random points from the benchmark measures
X = benchmark.input_sampler.sample(32)
Y = benchmark.output_sampler.sample(32)

# Compute the true forward map for points X
X.requires_grad_(True)
Y_true = benchmark.map_fwd(X, nograd=True)
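To score a fitted map against the true map returned by map_fwd, the paper uses an unexplained-variance percentage, L2-UVP = 100 * E_P||T_hat(x) - T*(x)||^2 / Var(Q). The sketch below estimates it from samples with NumPy. The function name `l2_uvp` is hypothetical (not part of src), and the formula is our reading of the paper, so check the paper for the exact definition.

```python
import numpy as np


def l2_uvp(Y_pred, Y_true, Y_target):
    """L2 unexplained variance percentage (hypothetical helper).

    Y_pred:   fitted map applied to samples X ~ P, shape (n, d)
    Y_true:   true map applied to the same X, shape (n, d)
    Y_target: samples from Q, shape (m, d), used to estimate Var(Q)
    """
    # Squared L2(P) distance between the fitted and the true map
    mse = np.mean(np.sum((Y_pred - Y_true) ** 2, axis=1))
    # Var(Q) as the trace of the covariance of Q
    var_q = np.sum(np.var(Y_target, axis=0))
    return 100.0 * mse / var_q
```

With the snippet above, the inputs would be a fitted map's output on X, Y_true, and a (preferably larger) sample from benchmark.output_sampler, each converted with .detach().cpu().numpy(). A score near 0% means the fitted map matches the true one; a constant map that outputs the mean of Q scores about 100%.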

Repository structure

All the experiments are provided as self-explanatory Jupyter notebooks (notebooks/). Auxiliary source code is kept in .py modules (src/). Continuous benchmark pairs are stored as .pt checkpoints (benchmarks/).

Evaluation of Existing Solvers

We provide all the code to evaluate existing dual OT solvers on our benchmark pairs. The qualitative results are shown below. For quantitative results, see the paper.
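Many dual W2 solvers rest on the dual form with a convex potential psi: W2^2(P, Q) = E_P||x||^2 + E_Q||y||^2 - 2 min_psi [E_P psi(x) + E_Q psi_bar(y)], where psi_bar is the convex conjugate of psi and W2^2 = min E||x - y||^2 (no factor 1/2). Solvers differ mainly in how they parametrize psi and approximate psi_bar. The toy sketch below, which is not code from this repository, evaluates the objective for a quadratic potential whose conjugate is known in closed form; the function names are hypothetical.

```python
import numpy as np


def quadratic_potential(x, A, b):
    # psi(x) = 0.5 x^T A x + b^T x, evaluated row-wise for x of shape (n, d)
    return 0.5 * np.sum((x @ A) * x, axis=1) + x @ b


def quadratic_conjugate(y, A, b):
    # For symmetric positive definite A:
    # psi_bar(y) = sup_x <x, y> - psi(x) = 0.5 (y - b)^T A^{-1} (y - b)
    r = y - b
    z = np.linalg.solve(A, r.T).T  # rows are A^{-1} (y_i - b)
    return 0.5 * np.sum(z * r, axis=1)


def dual_w2_squared(X, Y, A, b):
    # E||x||^2 + E||y||^2 - 2 (E_P psi(x) + E_Q psi_bar(y)); equals W2^2 at the optimal psi
    const = np.mean(np.sum(X ** 2, axis=1)) + np.mean(np.sum(Y ** 2, axis=1))
    dual = np.mean(quadratic_potential(X, A, b)) + np.mean(quadratic_conjugate(Y, A, b))
    return const - 2.0 * dual
```

By the Fenchel-Young inequality psi(x) + psi_bar(y) >= <x, y>, so on samples with Y = T(X) any other convex quadratic potential gives a value no larger than the empirical transport cost, and the true potential attains it. This is why minimizing over psi recovers both W2^2 and, through its gradient, the map.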

Testing Existing Solvers On High-Dimensional Benchmarks

  • notebooks/MM_test_hd_benchmark.ipynb -- testing [MM], [MMv2] solvers and their reversed versions
  • notebooks/MMv1_test_hd_benchmark.ipynb -- testing [MMv1] solver
  • notebooks/MM-B_test_hd_benchmark.ipynb -- testing [MM-B] solver
  • notebooks/W2_test_hd_benchmark.ipynb -- testing [W2] solver and its reversed version
  • notebooks/QC_test_hd_benchmark.ipynb -- testing [QC] solver
  • notebooks/LS_test_hd_benchmark.ipynb -- testing [LS] solver
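Besides a squared error, a fitted map can be compared with the true one by the direction of its displacement field x -> T(x) - x. The sketch below computes the cosine similarity between the fitted and true displacement fields in L2(P) from samples. We believe the paper reports a metric of this kind, but consult it for the exact definition; `displacement_cosine` is a hypothetical helper, not part of src.

```python
import numpy as np


def displacement_cosine(X, Y_pred, Y_true):
    """Cosine similarity in L2(P) between Y_pred - X and Y_true - X (hypothetical helper)."""
    u = Y_pred - X  # fitted displacements
    v = Y_true - X  # true displacements
    inner = np.mean(np.sum(u * v, axis=1))
    norm_u = np.sqrt(np.mean(np.sum(u * u, axis=1)))
    norm_v = np.sqrt(np.mean(np.sum(v * v, axis=1)))
    return inner / (norm_u * norm_v)
```

A value of 1 means the fitted displacements point the same way as the true ones up to a positive overall scale, so the metric ignores the magnitude of the displacements and complements a squared-error score.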

Testing Existing Solvers On Images Benchmark Pairs (CelebA 64x64 Aligned Faces)

  • notebooks/MM_test_images_benchmark.ipynb -- testing [MM] solver and its reversed version
  • notebooks/W2_test_images_benchmark.ipynb -- testing [W2] solver
  • notebooks/MM-B_test_images_benchmark.ipynb -- testing [MM-B] solver
  • notebooks/QC_test_images_benchmark.ipynb -- testing [QC] solver

The [LS], [MMv2] and [MMv1] solvers are not considered in this experiment.

Generative Modeling by Using Existing Solvers to Compute Loss

Warning: training may take several days before achieving reasonable FID scores!

  • notebooks/MM_test_image_generation.ipynb -- generative modeling by [MM] solver or its reversed version
  • notebooks/W2_test_image_generation.ipynb -- generative modeling by [W2] solver

For the [QC] solver, we used the code from the official WGAN-QC repository.

Training Benchmark Pairs From Scratch

This code is provided for completeness. It is not intended for retraining the existing benchmark pairs, but it can serve as a base for training new pairs on new datasets. High-dimensional benchmark pairs can be trained from scratch. Training image benchmark pairs requires generator network checkpoints; we used a WGAN-QC model to provide them.

  • notebooks/W2_train_hd_benchmark.ipynb -- training high-dimensional benchmark pairs by [W2] solver
  • notebooks/W2_train_images_benchmark.ipynb -- training images benchmark pairs by [W2] solver

Credits
