Official PyTorch implementation of the paper "Likelihood Training of Schrödinger Bridge using Forward-Backward SDEs Theory (SB-FBSDE)"


Likelihood Training of Schrödinger Bridge using Forward-Backward SDEs Theory [ICLR 2022]

Official PyTorch implementation of the paper "Likelihood Training of Schrödinger Bridge using Forward-Backward SDEs Theory (SB-FBSDE)" which introduces a new class of deep generative models that generalizes score-based models to fully nonlinear forward and backward diffusions.

SB-FBSDE result

This repo is co-maintained by Guan-Horng Liu and Tianrong Chen. Contact us if you have any questions! If you find this library useful, please cite ⬇️

@inproceedings{chen2022likelihood,
  title={Likelihood Training of Schr{\"o}dinger Bridge using Forward-Backward SDEs Theory},
  author={Chen, Tianrong and Liu, Guan-Horng and Theodorou, Evangelos A},
  booktitle={International Conference on Learning Representations},
  year={2022}
}

Installation

This code is developed with Python 3 and PyTorch >=1.7 (we recommend 1.8.1). First, install the dependencies with Anaconda and activate the environment sb-fbsde with

conda env create --file requirements.yaml python=3
conda activate sb-fbsde

Training

# Add --num-FID-sample <NUM_FID_SAMPLE> only for CIFAR-10.
python main.py \
  --problem-name <PROBLEM_NAME> \
  --forward-net <FORWARD_NET> \
  --backward-net <BACKWARD_NET> \
  --num-FID-sample <NUM_FID_SAMPLE> \
  --dir <DIR> \
  --log-tb

To train an SB-FBSDE from scratch, run the above command, where

  • PROBLEM_NAME is the dataset. We support gmm (2D mixture of Gaussians), checkerboard (2D toy dataset), mnist, celebA32, celebA64, and cifar10.
  • FORWARD_NET & BACKWARD_NET are the deep networks for the forward and backward drifts. We support Unet, ncsnpp, and a toy network for 2D datasets.
  • NUM_FID_SAMPLE is the number of generated images used to evaluate FID locally. We recommend 10000 for training CIFAR-10. Note that this requires first downloading the FID statistics checkpoint.
  • DIR specifies where the results (e.g. snapshots during training) are stored.
  • log-tb enables logging with TensorBoard.
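The flag rules above can also be captured in a small launcher. The following is a minimal sketch, not part of this repository: `build_train_command` is a hypothetical helper, and it only uses the flags documented above. It adds `--num-FID-sample` only when the problem is `cifar10`.

```python
import shlex


def build_train_command(problem_name, forward_net, backward_net, out_dir,
                        num_fid_sample=None, log_tb=True):
    """Assemble the `python main.py` argument list from the documented flags.

    Hypothetical helper for illustration; it is not provided by the repository.
    """
    cmd = [
        "python", "main.py",
        "--problem-name", problem_name,
        "--forward-net", forward_net,
        "--backward-net", backward_net,
        "--dir", out_dir,
    ]
    # --num-FID-sample is documented as a CIFAR-10-only flag.
    if problem_name == "cifar10" and num_fid_sample is not None:
        cmd += ["--num-FID-sample", str(num_fid_sample)]
    if log_tb:
        cmd.append("--log-tb")
    return cmd


if __name__ == "__main__":
    # The run directory name is an arbitrary example.
    cmd = build_train_command("cifar10", "Unet", "ncsnpp", "my-cifar10-run",
                              num_fid_sample=10000)
    # Print a shell-safe version of the command for inspection.
    print(shlex.join(cmd))
```

The returned list can be passed to `subprocess.run(cmd, check=True)` from the repository root to start training.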

Additionally, use --load to restore a previous checkpoint or a pre-trained model. For training CIFAR-10 specifically, we support loading the pre-trained NCSN++ as the backward policy of the first SB training stage. This works because, under proper initialization, the first SB training stage can degenerate to denoising score matching (see Appendix D of our paper for details).
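For reference, denoising score matching as it is commonly written for score-based models trains a score network by regressing onto the score of the perturbation kernel. The notation below ($s_\theta$ for the score network, $\lambda(t)$ for a positive weighting, $p_{0t}$ for the perturbation kernel) follows the score-based modeling literature and is an assumption here, not notation taken from this repository. Appendix D of the paper gives the precise statement of how the first SB stage relates to this objective.

```latex
\mathcal{L}_{\mathrm{DSM}}(\theta)
  = \mathbb{E}_{t}\,
    \mathbb{E}_{x_0 \sim p_{\mathrm{data}}}\,
    \mathbb{E}_{x_t \sim p_{0t}(\cdot \mid x_0)}
    \Big[ \lambda(t)\,
      \big\| s_\theta(x_t, t) - \nabla_{x_t} \log p_{0t}(x_t \mid x_0) \big\|_2^2
    \Big]
```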

Other configurations are detailed in options.py. The default configurations for each dataset are provided in the configs folder.

Evaluating the CIFAR-10 Checkpoint

To evaluate SB-FBSDE on CIFAR-10 (we achieve FID 3.01 and NLL 2.96), create a folder named checkpoint, then download the model checkpoint and the FID statistics checkpoint, either from Google Drive or with the following commands.

mkdir checkpoint && cd checkpoint

# FID stat checkpoint. This is needed whenever we
# need to compute FID during training or sampling.
gdown --id 1Tm_5nbUYKJiAtz2Rr_ARUY3KIFYxXQQD 

# SB-FBSDE model checkpoint for reproducing results in the paper.
gdown --id 1Kcy2IeecFK79yZDmnky36k4PR2yGpjyg 

After downloading the checkpoints, run one of the following commands to compute NLL or FID. Set the batch size --samp-bs according to your hardware.

# compute NLL
python main.py --problem-name cifar10 --forward-net Unet --backward-net ncsnpp --dir ICLR-2022-reproduce \
  --load checkpoint/ciifar10_sbfbsde_stage_8.npz --compute-NLL --samp-bs <BS>

# compute FID
python main.py --problem-name cifar10 --forward-net Unet --backward-net ncsnpp --dir ICLR-2022-reproduce \
  --load checkpoint/ciifar10_sbfbsde_stage_8.npz --compute-FID --samp-bs <BS> --num-FID-sample 50000 --use-corrector --snr 0.15
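When comparing NLL numbers, units matter. The sketch below assumes that the NLL of 2.96 quoted above is in bits per dimension, the usual convention for CIFAR-10. That is an assumption, since this README does not state the unit. It shows the standard conversion from a per-image NLL in nats. `nats_to_bits_per_dim` is a hypothetical helper, not a function from this repository, and it ignores any offset introduced by data rescaling or dequantization.

```python
import math


def nats_to_bits_per_dim(nll_nats, num_dims):
    """Convert a per-sample NLL in nats to bits per dimension.

    Hypothetical helper for illustration; divides by the number of data
    dimensions and by ln(2) to change the logarithm base from e to 2.
    """
    return nll_nats / (num_dims * math.log(2.0))


# A CIFAR-10 image has 3 channels of 32 x 32 pixels.
CIFAR10_DIMS = 3 * 32 * 32

if __name__ == "__main__":
    # Round trip: the per-image NLL in nats that corresponds to 2.96 bits/dim.
    total_nats = 2.96 * CIFAR10_DIMS * math.log(2.0)
    print(round(nats_to_bits_per_dim(total_nats, CIFAR10_DIMS), 2))
```

Whether `--compute-NLL` already reports bits per dimension is not stated here, so check the printed units before applying any conversion.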