Likelihood Training of Schrödinger Bridge using Forward-Backward SDEs Theory [ICLR 2022]

Official PyTorch implementation of the paper "Likelihood Training of Schrödinger Bridge using Forward-Backward SDEs Theory (SB-FBSDE)" which introduces a new class of deep generative models that generalizes score-based models to fully nonlinear forward and backward diffusions.
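To make "nonlinear forward diffusion" a bit more concrete, here is a generic Euler-Maruyama sketch of a forward SDE dX_t = [f(X_t, t) + g(t) z(X_t, t)] dt + g(t) dW_t. Here f is a fixed base drift, g is a scalar diffusion coefficient, and z is a learned policy such as a forward network. This parameterization and all names are assumptions for illustration, not this repo's sampler. Setting z ≡ 0 recovers a fixed forward SDE of the kind used by score-based models.

```python
import numpy as np

def simulate_forward_sde(x0, f, g, z, t0=0.0, t1=1.0, num_steps=100, rng=None):
    """Euler-Maruyama simulation of dX = [f(X, t) + g(t) * z(X, t)] dt + g(t) dW.

    f: base drift, g: scalar diffusion coefficient, z: learned policy
    (all hypothetical callables, used only for illustration).
    """
    rng = np.random.default_rng() if rng is None else rng
    dt = (t1 - t0) / num_steps
    x = np.array(x0, dtype=float)
    for i in range(num_steps):
        t = t0 + i * dt
        # Drift = fixed base drift plus the policy scaled by the diffusion coefficient.
        drift = f(x, t) + g(t) * z(x, t)
        # Brownian increment dW ~ N(0, dt).
        dw = np.sqrt(dt) * rng.standard_normal(x.shape)
        x = x + drift * dt + g(t) * dw
    return x

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x0 = rng.standard_normal((4, 2))
    # With z = 0 and f = 0 this is plain Brownian motion started at x0.
    xT = simulate_forward_sde(x0, f=lambda x, t: np.zeros_like(x), g=lambda t: 1.0,
                              z=lambda x, t: np.zeros_like(x), rng=rng)
    print(xT.shape)
```

With a learned, state-dependent z the forward process is no longer fixed in advance, which is the sense in which such models generalize score-based diffusions.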

SB-FBSDE result

This repo is co-maintained by Guan-Horng Liu and Tianrong Chen. Contact us if you have any questions! If you find this library useful, please cite ⬇️

@inproceedings{chen2022likelihood,
  title={Likelihood Training of Schr{\"o}dinger Bridge using Forward-Backward SDEs Theory},
  author={Chen, Tianrong and Liu, Guan-Horng and Theodorou, Evangelos A},
  booktitle={International Conference on Learning Representations},
  year={2022}
}

Installation

This code is developed with Python 3 and PyTorch >= 1.7 (we recommend 1.8.1). First, install the dependencies with Anaconda and activate the environment sb-fbsde with:

conda env create --file requirements.yaml python=3
conda activate sb-fbsde
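To double-check the PyTorch requirement stated above after activating the environment, a small sketch like the following can help. The helper names are hypothetical, and the version parsing is simplified: it keeps only the leading numeric parts of each component and ignores local suffixes such as "+cu111".

```python
import re
from importlib import metadata

MIN_TORCH = (1, 7)  # minimum version stated in this README

def parse_version(version):
    """Simplified parse: '1.8.1+cu111' -> (1, 8, 1)."""
    core = version.split("+")[0]
    nums = []
    for piece in core.split(".")[:3]:
        m = re.match(r"\d+", piece)  # leading digits only, e.g. '0a0' -> 0
        if m is None:
            break
        nums.append(int(m.group()))
    return tuple(nums)

def torch_version_ok(version, minimum=MIN_TORCH):
    # Tuple comparison: (1, 8, 1) >= (1, 7) is True.
    return parse_version(version) >= minimum

if __name__ == "__main__":
    try:
        installed = metadata.version("torch")
    except metadata.PackageNotFoundError:
        print("torch is not installed in this environment")
    else:
        print(installed, "OK" if torch_version_ok(installed) else "too old (need >= 1.7)")
```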

Training

# --num-FID-sample is only needed when training on CIFAR-10
python main.py \
  --problem-name <PROBLEM_NAME> \
  --forward-net <FORWARD_NET> \
  --backward-net <BACKWARD_NET> \
  --num-FID-sample <NUM_FID_SAMPLE> \
  --dir <DIR> \
  --log-tb

To train an SB-FBSDE from scratch, run the above command, where

  • PROBLEM_NAME is the dataset. We support gmm (2D mixture of Gaussians), checkerboard (2D toy dataset), mnist, celebA32, celebA64, and cifar10.
  • FORWARD_NET & BACKWARD_NET are the deep networks for the forward and backward drifts. We support Unet, ncsnpp, and a toy network for 2D datasets.
  • NUM_FID_SAMPLE is the number of generated images used to evaluate FID locally. We recommend 10000 when training on CIFAR-10. Note that this requires first downloading the FID statistics checkpoint.
  • DIR specifies where the results (e.g., snapshots during training) will be stored.
  • --log-tb enables logging with TensorBoard.
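If you launch many runs from a script, the training command can also be assembled in Python. The sketch below uses only the flags documented above; build_train_cmd and SUPPORTED_PROBLEMS are hypothetical helpers, not part of the repo, and the network names are passed through unchecked.

```python
import shlex

# Dataset names documented above for --problem-name.
SUPPORTED_PROBLEMS = {"gmm", "checkerboard", "mnist", "celebA32", "celebA64", "cifar10"}

def build_train_cmd(problem, forward_net, backward_net, out_dir,
                    num_fid_sample=None, log_tb=True):
    """Return an argv list for main.py (hypothetical helper)."""
    if problem not in SUPPORTED_PROBLEMS:
        raise ValueError(f"unsupported problem: {problem}")
    cmd = ["python", "main.py",
           "--problem-name", problem,
           "--forward-net", forward_net,
           "--backward-net", backward_net]
    # --num-FID-sample is only meant for CIFAR-10 training, so drop it otherwise.
    if problem == "cifar10" and num_fid_sample is not None:
        cmd += ["--num-FID-sample", str(num_fid_sample)]
    cmd += ["--dir", out_dir]
    if log_tb:
        cmd.append("--log-tb")
    return cmd

if __name__ == "__main__":
    # shlex.join (Python 3.8+) quotes each argument for safe copy-paste into a shell.
    print(shlex.join(build_train_cmd("cifar10", "Unet", "ncsnpp", "cifar10-run",
                                     num_fid_sample=10000)))
```

The returned list can be passed directly to subprocess.run, which avoids shell-quoting issues with the --dir path.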

Additionally, use --load to restore a previous checkpoint or a pre-trained model. For CIFAR-10 specifically, we support loading the pre-trained NCSN++ as the backward policy of the first SB training stage, because under proper initialization the first SB training stage can degenerate to denoising score matching (see Appendix D of our paper for details).
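For intuition on the denoising score matching (DSM) connection mentioned above, here is a minimal NumPy sketch of the standard DSM objective at a single noise level: perturb data as x_t = x_0 + σε and regress the score model onto the perturbation-kernel score -ε/σ. This is the textbook objective, not this repo's training loss; dsm_loss, score_fn, and the single fixed σ are illustrative assumptions.

```python
import numpy as np

def dsm_loss(score_fn, x0, sigma, rng):
    """Monte Carlo DSM loss at one noise level sigma (illustrative sketch)."""
    eps = rng.standard_normal(x0.shape)
    xt = x0 + sigma * eps
    # Score of the Gaussian perturbation kernel q(x_t | x_0) = N(x_0, sigma^2 I).
    target = -eps / sigma
    diff = score_fn(xt) - target
    # Squared error summed over dimensions, averaged over samples.
    return float(np.mean(np.sum(diff ** 2, axis=-1)))

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x0 = rng.standard_normal((20000, 2))  # toy data ~ N(0, I)
    sigma = 0.5
    # For N(0, I) data the perturbed marginal is N(0, (1 + sigma^2) I), so its score is known.
    exact = lambda x: -x / (1.0 + sigma ** 2)
    zero = lambda x: np.zeros_like(x)
    print(dsm_loss(exact, x0, sigma, rng), dsm_loss(zero, x0, sigma, rng))
```

The true marginal score attains the lowest DSM loss, which is why a network trained this way (such as a pre-trained NCSN++) can serve as a reasonable initialization for a backward policy.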

Other configurations are detailed in options.py. The default configurations for each dataset are provided in the configs folder.

Evaluating the CIFAR-10 Checkpoint

To evaluate SB-FBSDE on CIFAR-10 (we achieve FID 3.01 and NLL 2.96), create a folder named checkpoint, then download the model checkpoint and the FID statistics checkpoint, either from Google Drive or with the following commands.

mkdir checkpoint && cd checkpoint

# FID statistics checkpoint. This is needed whenever
# FID is computed during training or sampling.
gdown --id 1Tm_5nbUYKJiAtz2Rr_ARUY3KIFYxXQQD

# SB-FBSDE model checkpoint for reproducing the results in the paper.
gdown --id 1Kcy2IeecFK79yZDmnky36k4PR2yGpjyg

# return to the repository root before running main.py
cd ..
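FID compares Gaussian fits (mean and covariance) of Inception features for generated versus reference images. The downloaded FID statistics checkpoint presumably holds the reference mean and covariance, though its exact format is an assumption here. Below is a NumPy-only sketch of the Fréchet distance between two such Gaussians, ||μ1 - μ2||² + Tr(Σ1 + Σ2 - 2(Σ1Σ2)^{1/2}); frechet_distance is a hypothetical helper, not the repo's evaluation code.

```python
import numpy as np

def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Frechet distance between N(mu1, sigma1) and N(mu2, sigma2)."""
    mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)
    diff = mu1 - mu2
    # Tr((S1 S2)^{1/2}) is the sum of square roots of the eigenvalues of S1 @ S2,
    # which are real and non-negative for PSD S1, S2 (clip round-off noise).
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    tr_covmean = np.sum(np.sqrt(np.clip(eigvals.real, 0.0, None)))
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)

if __name__ == "__main__":
    eye = np.eye(2)
    print(frechet_distance(np.zeros(2), eye, np.array([3.0, 4.0]), eye))
```

Using eigenvalues avoids a SciPy dependency; widely used FID tools instead call scipy.linalg.sqrtm on the covariance product.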

After downloading the checkpoints, run one of the following commands to compute either NLL or FID. Set the sampling batch size --samp-bs according to your hardware.

# compute NLL
python main.py --problem-name cifar10 --forward-net Unet --backward-net ncsnpp --dir ICLR-2022-reproduce \
  --load checkpoint/ciifar10_sbfbsde_stage_8.npz --compute-NLL --samp-bs <BS>
# compute FID
python main.py --problem-name cifar10 --forward-net Unet --backward-net ncsnpp --dir ICLR-2022-reproduce \
  --load checkpoint/ciifar10_sbfbsde_stage_8.npz --compute-FID --samp-bs <BS> --num-FID-sample 50000 --use-corrector --snr 0.15
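NLL on CIFAR-10 is conventionally reported in bits per dimension. Assuming that convention for the 2.96 above, a per-image NLL in nats converts as follows, with D = 3 × 32 × 32 = 3072. This sketch ignores any dequantization or data-rescaling offsets the evaluation code may apply, and nats_to_bits_per_dim is a hypothetical helper.

```python
import math

def nats_to_bits_per_dim(nll_nats, num_dims=3 * 32 * 32):
    """Convert a per-image NLL in nats to bits per dimension."""
    # Divide by ln 2 to go from nats to bits, then by the number of dimensions.
    return nll_nats / (num_dims * math.log(2.0))

if __name__ == "__main__":
    # Roughly 6303 nats per 32x32x3 image corresponds to about 2.96 bits/dim.
    print(round(nats_to_bits_per_dim(6303.0), 3))
```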