PyTorch Implementation of DSB for Score Based Generative Modeling. Experiments managed using Hydra.

Overview

Diffusion Schrödinger Bridge with Applications to Score-Based Generative Modeling

This repository contains the implementation for the paper Diffusion Schrödinger Bridge with Applications to Score-Based Generative Modeling.

If using this code, please cite the paper:

    @article{de2021diffusion,
              title={Diffusion Schr$\backslash$" odinger Bridge with Applications to Score-Based Generative Modeling},
              author={De Bortoli, Valentin and Thornton, James and Heng, Jeremy and Doucet, Arnaud},
              journal={arXiv preprint arXiv:2106.01357},
              year={2021}
            }

Contributors

  • Valentin De Bortoli
  • James Thornton
  • Jeremy Heng
  • Arnaud Doucet

What is a Schrödinger bridge?

The Schrödinger Bridge (SB) problem is a classical problem appearing in applied mathematics, optimal control and probability; see [1, 2, 3]. In the discrete-time setting, it takes the following (dynamic) form. Consider as reference density p(x0:N) describing the process adding noise to the data. We aim to find p*(x0:N) such that p*(x0) = pdata(x0) and p*(xN) = pprior(xN) and minimize the Kullback-Leibler divergence between p* and p. In this work we introduce Diffusion Schrodinger Bridge (DSB), a new algorithm which uses score-matching approaches [4] to approximate the Iterative Proportional Fitting algorithm, an iterative method to find the solutions of the SB problem. DSB can be seen as a refinement of existing score-based generative modeling methods [5, 6].

Schrodinger bridge

Installation

This project can be installed from its git repository.

  1. Obtain the sources by:

    git clone https://github.com/anon284/schrodinger_bridge.git

or, if git is unavailable, download as a ZIP from GitHub https://github.com/.

  1. Install:

    conda env create -f conda.yaml

    conda activate bridge

  2. Download data examples:

    • CelebA: python data.py --data celeba --data_dir './data/'
    • MNIST: python data.py --data mnist --data_dir './data/'

How to use this code?

  1. Train Networks:
  • 2d: python main.py dataset=2d model=Basic num_steps=20 num_iter=5000
  • mnist python main.py dataset=stackedmnist num_steps=30 model=UNET num_iter=5000 data_dir=<insert filepath of data dir <local paths/data/>
  • celeba python main.py dataset=celeba num_steps=50 model=UNET num_iter=5000 data_dir=<insert filepath of data dir <local paths/data/>

Checkpoints and sampled images will be saved to a newly created directory. If GPU has insufficient memory, then reduce cache size. 2D dataset should train on CPU. MNIST and CelebA was ran on 2 high-memory V100 GPUs.

References

.. [1] Hans Föllmer Random fields and diffusion processes In: École d'été de Probabilités de Saint-Flour 1985-1987

.. [2] Christian Léonard A survey of the Schrödinger problem and some of its connections with optimal transport In: Discrete & Continuous Dynamical Systems-A 2014

.. [3] Yongxin Chen, Tryphon Georgiou and Michele Pavon Optimal Transport in Systems and Control In: Annual Review of Control, Robotics, and Autonomous Systems 2020

.. [4] Aapo Hyvärinen and Peter Dayan Estimation of non-normalized statistical models by score matching In: Journal of Machine Learning Research 2005

.. [5] Yang Song and Stefano Ermon Generative modeling by estimating gradients of the data distribution In: Advances in Neural Information Processing Systems 2019

.. [6] Jonathan Ho, Ajay Jain and Pieter Abbeel Denoising diffusion probabilistic models In: Advances in Neural Information Processing Systems 2020

Owner
James Thornton
James Thornton
Official code of CVPR 2021's PLOP: Learning without Forgetting for Continual Semantic Segmentation

PLOP: Learning without Forgetting for Continual Semantic Segmentation This repository contains all of our code. It is a modified version of Cermelli e

Arthur Douillard 116 Dec 14, 2022
Pytorch code for "State-only Imitation with Transition Dynamics Mismatch" (ICLR 2020)

This repo contains code for our paper State-only Imitation with Transition Dynamics Mismatch published at ICLR 2020. The code heavily uses the RL mach

20 Sep 08, 2022
Over-the-Air Ensemble Inference with Model Privacy

Over-the-Air Ensemble Inference with Model Privacy This repository contains simulations for our private ensemble inference method. Installation Instal

Selim Firat Yilmaz 1 Jun 29, 2022
Heart Arrhythmia Classification

This program takes and input of an ECG in European Data Format (EDF) and outputs the classification for heartbeats into normal vs different types of arrhythmia . It uses a deep learning model for cla

4 Nov 02, 2022
Towards Calibrated Model for Long-Tailed Visual Recognition from Prior Perspective

Towards Calibrated Model for Long-Tailed Visual Recognition from Prior Perspective Zhengzhuo Xu, Zenghao Chai, Chun Yuan This is the PyTorch implement

Sincere 16 Dec 15, 2022
Official implementation of NeurIPS'21: Implicit SVD for Graph Representation Learning

isvd Official implementation of NeurIPS'21: Implicit SVD for Graph Representation Learning If you find this code useful, you may cite us as: @inprocee

Sami Abu-El-Haija 16 Jan 08, 2023
Gym Threat Defense

Gym Threat Defense The Threat Defense environment is an OpenAI Gym implementation of the environment defined as the toy example in Optimal Defense Pol

Hampus Ramström 5 Dec 08, 2022
meProp: Sparsified Back Propagation for Accelerated Deep Learning (ICML 2017)

meProp The codes were used for the paper meProp: Sparsified Back Propagation for Accelerated Deep Learning with Reduced Overfitting (ICML 2017) [pdf]

LancoPKU 107 Nov 18, 2022
Structural Constraints on Information Content in Human Brain States

Structural Constraints on Information Content in Human Brain States Code accompanying the paper "The information content of brain states is explained

Leon Weninger 3 Sep 07, 2022
Face recognize system

FRS Face_recognize_system This project contains my work that target on solving some problems of FRS: Face detection: Retinaface Face anti-spoofing: Fo

Tran Anh Tuan 4 Nov 18, 2021
Python implementation of Wu et al (2018)'s registration fusion

reg-fusion Projection of a central sulcus probability map using the RF-ANTs approach (right hemisphere shown). This is a Python implementation of Wu e

Dan Gale 26 Nov 12, 2021
SpecAugmentPyTorch - A Pytorch (support batch and channel) implementation of GoogleBrain's SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition

SpecAugment An implementation of SpecAugment for Pytorch How to use Install pytorch, version=1.9.0 (new feature (torch.Tensor.take_along_dim) is used

IMLHF 3 Oct 11, 2022
Code implementation of "Sparsity Probe: Analysis tool for Deep Learning Models"

Sparsity Probe: Analysis tool for Deep Learning Models This repository is a limited implementation of Sparsity Probe: Analysis tool for Deep Learning

3 Jun 09, 2021
Use graph-based analysis to re-classify stocks and to improve Markowitz portfolio optimization

Dynamic Stock Industrial Classification Use graph-based analysis to re-classify stocks and experiment different re-classification methodologies to imp

Sheng Yang 10 Dec 05, 2022
Code for "ATISS: Autoregressive Transformers for Indoor Scene Synthesis", NeurIPS 2021

ATISS: Autoregressive Transformers for Indoor Scene Synthesis This repository contains the code that accompanies our paper ATISS: Autoregressive Trans

138 Dec 22, 2022
A curated list of the latest breakthroughs in AI (in 2021) by release date with a clear video explanation, link to a more in-depth article, and code.

2021: A Year Full of Amazing AI papers- A Review 📌 A curated list of the latest breakthroughs in AI by release date with a clear video explanation, l

Louis-François Bouchard 2.9k Dec 31, 2022
Proximal Backpropagation - a neural network training algorithm that takes implicit instead of explicit gradient steps

Proximal Backpropagation Proximal Backpropagation (ProxProp) is a neural network training algorithm that takes implicit instead of explicit gradient s

Thomas Frerix 40 Dec 17, 2022
Mask-invariant Face Recognition through Template-level Knowledge Distillation

Mask-invariant Face Recognition through Template-level Knowledge Distillation This is the official repository of "Mask-invariant Face Recognition thro

Fadi Boutros 35 Dec 06, 2022
ZeroGen: Efficient Zero-shot Learning via Dataset Generation

ZEROGEN This repository contains the code for our paper “ZeroGen: Efficient Zero

Jiacheng Ye 31 Dec 30, 2022
Visualize Camera's Pose Using Extrinsic Parameter by Plotting Pyramid Model on 3D Space

extrinsic2pyramid Visualize Camera's Pose Using Extrinsic Parameter by Plotting Pyramid Model on 3D Space Intro A very simple and straightforward modu

JEONG HYEONJIN 106 Dec 28, 2022