Official PyTorch implementation of the paper "Likelihood Training of Schrödinger Bridge using Forward-Backward SDEs Theory (SB-FBSDE)"

Related tags

Deep LearningSB-FBSDE
Overview

Likelihood Training of Schrödinger Bridge using Forward-Backward SDEs Theory [ICLR 2022]

Official PyTorch implementation of the paper "Likelihood Training of Schrödinger Bridge using Forward-Backward SDEs Theory (SB-FBSDE)" which introduces a new class of deep generative models that generalizes score-based models to fully nonlinear forward and backward diffusions.

SB-FBSDE result

This repo is co-maintained by Guan-Horng Liu and Tianrong Chen. Contact us if you have any questions! If you find this library useful, please cite ⬇️

@inproceedings{chen2022likelihood,
  title={Likelihood Training of Schr{\"o}dinger Bridge using Forward-Backward SDEs Theory},
  author={Chen, Tianrong and Liu, Guan-Horng and Theodorou, Evangelos A},
  booktitle={International Conference on Learning Representations},
  year={2022}
}

Installation

This code is developed with Python3. PyTorch >=1.7 (we recommend 1.8.1). First, install the dependencies with Anaconda and activate the environment sb-fbsde with

conda env create --file requirements.yaml python=3
conda activate sb-fbsde

Training

python main.py \
  --problem-name <PROBLEM_NAME> \
  --forward-net <FORWARD_NET> \
  --backward-net <BACKWARD_NET> \
  --num-FID-sample <NUM_FID_SAMPLE> \ # add this flag only for CIFAR-10
  --dir <DIR> \
  --log-tb 

To train an SB-FBSDE from scratch, run the above command, where

  • PROBLEM_NAME is the dataset. We support gmm (2D mixture of Gaussian), checkerboard (2D toy dataset), mnist, celebA32, celebA64, cifar10.
  • FORWARD_NET & BACKWARD_NET are the deep networks for forward and backward drifts. We support Unet, nscnpp, and a toy network for 2D datasets.
  • NUM_FID_SAMPLE is the number of generated images used to evaluate FID locally. We recommend 10000 for training CIFAR-10. Note that this requires first downloading the FID statistics checkpoint.
  • DIR specifies where the results (e.g. snapshots during training) shall be stored.
  • log-tb enables logging with Tensorboard.

Additionally, use --load to restore previous checkpoint or pre-trained model. For training CIFAR-10 specifically, we support loading the pre-trained NCSN++ as the backward policy of the first SB training stage (this is because the first SB training stage can degenerate to denoising score matching under proper initialization; see more details in Appendix D of our paper).

Other configurations are detailed in options.py. The default configurations for each dataset are provided in the configs folder.

Evaluating the CIFAR-10 Checkpoint

To evaluate SB-FBSDE on CIFAR-10 (we achieve FID 3.01 and NLL 2.96), create a folder checkpoint then download the model checkpoint and FID statistics checkpoint either from Google Drive or through the following commands.

mkdir checkpoint && cd checkpoint

# FID stat checkpoint. This's needed whenever we
# need to compute FID during training or sampling.
gdown --id 1Tm_5nbUYKJiAtz2Rr_ARUY3KIFYxXQQD 

# SB-FBSDE model checkpoint for reproducing results in the paper.
gdown --id 1Kcy2IeecFK79yZDmnky36k4PR2yGpjyg 

After downloading the checkpoints, run the following commands for computing either NLL or FID. Set the batch size --samp-bs properly depending on your hardware.

# compute NLL
python main.py --problem-name cifar10 --forward-net Unet --backward-net ncsnpp --dir ICLR-2022-reproduce
  --load checkpoint/ciifar10_sbfbsde_stage_8.npz --compute-NLL --samp-bs <BS>
# compute FID
python main.py --problem-name cifar10 --forward-net Unet --backward-net ncsnpp --dir ICLR-2022-reproduce
  --load checkpoint/ciifar10_sbfbsde_stage_8.npz --compute-FID --samp-bs <BS> --num-FID-sample 50000 --use-corrector --snr 0.15
Owner
Guan-Horng Liu
CMU RI → Uber ATG → GaTech ML
Guan-Horng Liu
Official Code Implementation of the paper : XAI for Transformers: Better Explanations through Conservative Propagation

Official Code Implementation of The Paper : XAI for Transformers: Better Explanations through Conservative Propagation For the SST-2 and IMDB expermin

Ameen Ali 23 Dec 30, 2022
Official Pytorch implementation of MixMo framework

MixMo: Mixing Multiple Inputs for Multiple Outputs via Deep Subnetworks Official PyTorch implementation of the MixMo framework | paper | docs Alexandr

79 Nov 07, 2022
Iris prediction model is used to classify iris species created julia's DecisionTree, DataFrames, JLD2, PlotlyJS and Statistics packages.

Iris Species Predictor Iris prediction is used to classify iris species using their sepal length, sepal width, petal length and petal width created us

Siva Prakash 2 Jan 06, 2022
Spontaneous Facial Micro Expression Recognition using 3D Spatio-Temporal Convolutional Neural Networks

Spontaneous Facial Micro Expression Recognition using 3D Spatio-Temporal Convolutional Neural Networks Abstract Facial expression recognition in video

Bogireddy Sai Prasanna Teja Reddy 103 Dec 29, 2022
Canonical Capsules: Unsupervised Capsules in Canonical Pose (NeurIPS 2021)

Canonical Capsules: Unsupervised Capsules in Canonical Pose (NeurIPS 2021) Introduction This is the official repository for the PyTorch implementation

165 Dec 07, 2022
Semi-Supervised Learning with Ladder Networks in Keras. Get 98% test accuracy on MNIST with just 100 labeled examples !

Semi-Supervised Learning with Ladder Networks in Keras This is an implementation of Ladder Network in Keras. Ladder network is a model for semi-superv

Divam Gupta 101 Sep 07, 2022
Automatic Differentiation Multipole Moment Molecular Forcefield

Automatic Differentiation Multipole Moment Molecular Forcefield Performance notes On a single gpu, using waterbox_31ang.pdb example from MPIDplugin wh

4 Jan 07, 2022
AAAI 2022: Stationary diffusion state neural estimation

Stationary Diffusion State Neural Estimation Although many graph-based clustering methods attempt to model the stationary diffusion state in their obj

绽琨 33 Nov 24, 2022
Mining-the-Social-Web-3rd-Edition - The official online compendium for Mining the Social Web, 3rd Edition (O'Reilly, 2018)

Mining the Social Web, 3rd Edition The official code repository for Mining the Social Web, 3rd Edition (O'Reilly, 2019). The book is available from Am

Mikhail Klassen 838 Jan 01, 2023
Statistical and Algorithmic Investing Strategies for Everyone

Eiten - Algorithmic Investing Strategies for Everyone Eiten is an open source toolkit by Tradytics that implements various statistical and algorithmic

Tradytics 2.5k Jan 02, 2023
HIVE: Evaluating the Human Interpretability of Visual Explanations

HIVE: Evaluating the Human Interpretability of Visual Explanations Project Page | Paper This repo provides the code for HIVE, a human evaluation frame

Princeton Visual AI Lab 16 Dec 13, 2022
Interpolation-based reduced-order models

Interpolation-reduced-order-models Interpolation-based reduced-order models High-fidelity computational fluid dynamics (CFD) solutions are time consum

Donovan Blais 1 Jan 10, 2022
Fastquant - Backtest and optimize your trading strategies with only 3 lines of code!

fastquant 🤓 Bringing backtesting to the mainstream fastquant allows you to easily backtest investment strategies with as few as 3 lines of python cod

Lorenzo Ampil 1k Dec 29, 2022
AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty

AugMix Introduction We propose AugMix, a data processing technique that mixes augmented images and enforces consistent embeddings of the augmented ima

Google Research 876 Dec 17, 2022
Time Delayed NN implemented in pytorch

Pytorch Time Delayed NN Time Delayed NN implemented in PyTorch. Usage kernels = [(1, 25), (2, 50), (3, 75), (4, 100), (5, 125), (6, 150)] tdnn = TDNN

Daniil Gavrilov 79 Aug 04, 2022
source code for https://arxiv.org/abs/2005.11248 "Accelerating Antimicrobial Discovery with Controllable Deep Generative Models and Molecular Dynamics"

Accelerating Antimicrobial Discovery with Controllable Deep Generative Models and Molecular Dynamics This work will be published in Nature Biomedical

International Business Machines 71 Nov 15, 2022
A curated list of awesome deep long-tailed learning resources.

A curated list of awesome deep long-tailed learning resources.

vanint 210 Dec 25, 2022
Realtime Face Anti Spoofing with Face Detector based on Deep Learning using Tensorflow/Keras and OpenCV

Realtime Face Anti-Spoofing Detection 🤖 Realtime Face Anti Spoofing Detection with Face Detector to detect real and fake faces Please star this repo

Prem Kumar 86 Aug 03, 2022
Run containerized, rootless applications with podman

Why? restrict scope of file system access run any application without root privileges creates usable "Desktop applications" to integrate into your nor

119 Dec 27, 2022
Spherical Confidence Learning for Face Recognition, accepted to CVPR2021.

Sphere Confidence Face (SCF) This repository contains the PyTorch implementation of Sphere Confidence Face (SCF) proposed in the CVPR2021 paper: Shen

Maths 70 Dec 09, 2022