PyTorch implementation for Score-Based Generative Modeling through Stochastic Differential Equations (ICLR 2021, Oral)

Overview

This repo contains a PyTorch implementation for the paper Score-Based Generative Modeling through Stochastic Differential Equations by Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole.


We propose a unified framework that generalizes and improves previous work on score-based generative models through the lens of stochastic differential equations (SDEs). In particular, we can transform data to a simple noise distribution with a continuous-time stochastic process described by an SDE. This SDE can be reversed for sample generation if we know the score of the marginal distributions at each intermediate time step, which can be estimated with score matching. The basic idea is captured in the figure below:

[Schematic: data is perturbed into noise by a forward SDE and regenerated from noise by the reverse-time SDE]
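Concretely (following the paper, where f and g are the drift and diffusion coefficients, w and w̄ are forward- and reverse-time Wiener processes, and the score ∇_x log p_t(x) is exactly what score matching estimates), the forward SDE and its reverse-time counterpart are:

    \mathrm{d}\mathbf{x} = \mathbf{f}(\mathbf{x}, t)\,\mathrm{d}t + g(t)\,\mathrm{d}\mathbf{w}                                            \quad \text{(forward SDE)}
    \mathrm{d}\mathbf{x} = \left[\mathbf{f}(\mathbf{x}, t) - g(t)^2 \nabla_{\mathbf{x}} \log p_t(\mathbf{x})\right]\mathrm{d}t + g(t)\,\mathrm{d}\bar{\mathbf{w}}  \quad \text{(reverse-time SDE)}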

Our work enables a better understanding of existing approaches, new sampling algorithms, exact likelihood computation, uniquely identifiable encoding, latent code manipulation, and brings new conditional generation abilities (including but not limited to class-conditional generation, inpainting and colorization) to the family of score-based generative models.

All combined, we achieved an FID of 2.20 and an Inception score of 9.89 for unconditional generation on CIFAR-10, as well as high-fidelity generation of 1024px Celeba-HQ images (samples below). In addition, we obtained a likelihood value of 2.99 bits/dim on uniformly dequantized CIFAR-10 images.

[FFHQ samples]

What does this code do?

Aside from the NCSN++ and DDPM++ models in our paper, this codebase also re-implements many previous score-based models in one place, including NCSN from Generative Modeling by Estimating Gradients of the Data Distribution, NCSNv2 from Improved Techniques for Training Score-Based Generative Models, and DDPM from Denoising Diffusion Probabilistic Models.

It supports training new models, and evaluating the sample quality and likelihoods of existing models. We carefully designed the code to be modular and easily extensible to new SDEs, predictors, and correctors.

JAX version

Please find a JAX implementation here, which additionally supports class-conditional generation with a pre-trained classifier, and resuming an evaluation process after pre-emption.

JAX vs. PyTorch

In general, this PyTorch version consumes less memory but runs slower than JAX. Here is a benchmark on training an NCSN++ cont. model with the VE SDE. The hardware is 4x Nvidia Tesla V100 GPUs (32 GB each).

| Framework | Time (seconds per step) | Total memory usage (GB) |
|---|---|---|
| PyTorch | 0.56 | 20.6 |
| JAX (n_jitted_steps=1) | 0.30 | 29.7 |
| JAX (n_jitted_steps=5) | 0.20 | 74.8 |

How to run the code

Dependencies

Run the following to install a subset of the necessary Python packages for our code:

pip install -r requirements.txt

Stats files for quantitative evaluation

We provide the stats file for CIFAR-10. You can download cifar10_stats.npz and save it to assets/stats/. Check out issue #5 for how to compute stats files for new datasets.
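As an illustrative sanity check (the path follows the layout above; the array names are whatever the downloaded archive contains), you can verify the file loads correctly:

    # Illustrative sanity check: confirm the downloaded stats file is readable
    # and list the arrays it contains. Assumes it was saved under assets/stats/.
    import numpy as np

    stats = np.load('assets/stats/cifar10_stats.npz')
    print({key: stats[key].shape for key in stats.files})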

Usage

Train and evaluate our models through main.py.

main.py:
  --config: Training configuration.
    (default: 'None')
  --eval_folder: The folder name for storing evaluation results
    (default: 'eval')
  --mode: <train|eval>: Running mode: train or eval
  --workdir: Working directory
  • config is the path to the config file. Our prescribed config files are provided in configs/. They are formatted according to ml_collections and should be quite self-explanatory.

    Naming conventions of config files: the path of a config file is a combination of the following dimensions:

    • dataset: One of cifar10, celeba, celebahq, celebahq_256, ffhq_256, ffhq.
    • model: One of ncsn, ncsnv2, ncsnpp, ddpm, ddpmpp.
    • continuous: whether the model is trained with continuously sampled time steps.
  • workdir is the path that stores all artifacts of one experiment, like checkpoints, samples, and evaluation results.

  • eval_folder is the name of a subfolder in workdir that stores all artifacts of the evaluation process, like meta checkpoints for pre-emption prevention, image samples, and numpy dumps of quantitative results.

  • mode is either "train" or "eval". When set to "train", it starts the training of a new model, or resumes the training of an old model if its meta-checkpoints (for resuming after pre-emption in a cloud environment) exist in workdir/checkpoints-meta. When set to "eval", it can do an arbitrary combination of the following:

    • Evaluate the loss function on the test / validation dataset.

    • Generate a fixed number of samples and compute their Inception score, FID, or KID. Prior to evaluation, stats files must have already been downloaded/computed and stored in assets/stats.

    • Compute the log-likelihood on the training or test dataset.

    These functionalities can be configured through config files, or more conveniently, through the command-line support of the ml_collections package. For example, to generate samples and evaluate sample quality, supply the --config.eval.enable_sampling flag; to compute log-likelihoods, supply the --config.eval.enable_bpd flag, and specify --config.eval.dataset=train/test to indicate whether to compute the likelihoods on the training or test dataset.
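For example, the following commands are an illustrative run (the config path is one of the files shipped in configs/; the workdir path is arbitrary): the first trains an NCSN++ cont. model on CIFAR-10 with the VE SDE, and the second samples from it and computes likelihoods on the test set.

    python main.py --config ./configs/ve/cifar10_ncsnpp_continuous.py \
      --mode train --workdir ./exp/ve_cifar10_ncsnpp_continuous

    python main.py --config ./configs/ve/cifar10_ncsnpp_continuous.py \
      --mode eval --workdir ./exp/ve_cifar10_ncsnpp_continuous --eval_folder eval \
      --config.eval.enable_sampling --config.eval.enable_bpd --config.eval.dataset=test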

How to extend the code

  • New SDEs: inherit the sde_lib.SDE abstract class and implement all abstract methods. The discretize() method is optional; the default is the Euler-Maruyama discretization. Existing sampling methods and likelihood computation will automatically work for the new SDE.
  • New predictors: inherit the sampling.Predictor abstract class, implement the update_fn abstract method, and register its name with @register_predictor (see the sketch after this list). The new predictor can be directly used in sampling.get_pc_sampler for Predictor-Corrector sampling, and in all other controllable generation methods in controllable_generation.py.
  • New correctors: inherit the sampling.Corrector abstract class, implement the update_fn abstract method, and register its name with @register_corrector. The new corrector can be directly used in sampling.get_pc_sampler and in all other controllable generation methods in controllable_generation.py.
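Here is a minimal sketch of a custom predictor, modeled on the Euler-Maruyama update; the exact base-class signature and the rsde attribute are assumptions based on sampling.py, so check the built-in predictors there before copying.

    # Sketch of a custom predictor. Assumptions (verify against sampling.py): the
    # Predictor base class stores the reverse-time SDE as self.rsde, and update_fn
    # returns (x, x_mean).
    import torch

    from sampling import Predictor, register_predictor

    @register_predictor(name='my_euler')
    class MyEulerPredictor(Predictor):
      def update_fn(self, x, t):
        dt = -1. / self.rsde.N                         # negative step: integrate backwards in time
        z = torch.randn_like(x)
        drift, diffusion = self.rsde.sde(x, t)         # reverse-time drift and diffusion
        x_mean = x + drift * dt                        # deterministic part of the update
        x = x_mean + diffusion[:, None, None, None] * (-dt) ** 0.5 * z
        return x, x_mean                               # noisy sample and its mean

Once registered, the name can be selected the same way the built-in predictors are (e.g. via the sampling section of a config file), assuming the standard name-based lookup.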

Pretrained checkpoints

All checkpoints are provided in this Google drive.

Instructions: You may find two checkpoints for some models. The first checkpoint (with a smaller number) is the one for which we reported FID scores in Table 3 of our paper (it also corresponds to the FID and IS columns in the table below). The second checkpoint (with a larger number) is the one for which we reported likelihood values and FIDs of black-box ODE samplers in Table 2 of our paper (the FID (ODE) and NLL (bits/dim) columns in the table below). The former corresponds to the smallest FID during the course of training (measured every 50k iterations); the latter is the last checkpoint of training.

Per Google's policy, we cannot release our original CelebA and CelebA-HQ checkpoints. That said, I have re-trained models on FFHQ 1024px, FFHQ 256px and CelebA-HQ 256px with personal resources, and they achieved similar performance to our internal checkpoints.

Here is a detailed list of checkpoints and their results reported in the paper. FID (ODE) corresponds to the sample quality of a black-box ODE solver applied to the probability flow ODE.

| Checkpoint path | FID | IS | FID (ODE) | NLL (bits/dim) |
|---|---|---|---|---|
| ve/cifar10_ncsnpp/ | 2.45 | 9.73 | - | - |
| ve/cifar10_ncsnpp_continuous/ | 2.38 | 9.83 | - | - |
| ve/cifar10_ncsnpp_deep_continuous/ | 2.20 | 9.89 | - | - |
| vp/cifar10_ddpm/ | 3.24 | - | 3.37 | 3.28 |
| vp/cifar10_ddpm_continuous | - | - | 3.69 | 3.21 |
| vp/cifar10_ddpmpp | 2.78 | 9.64 | - | - |
| vp/cifar10_ddpmpp_continuous | 2.55 | 9.58 | 3.93 | 3.16 |
| vp/cifar10_ddpmpp_deep_continuous | 2.41 | 9.68 | 3.08 | 3.13 |
| subvp/cifar10_ddpm_continuous | - | - | 3.56 | 3.05 |
| subvp/cifar10_ddpmpp_continuous | 2.61 | 9.56 | 3.16 | 3.02 |
| subvp/cifar10_ddpmpp_deep_continuous | 2.41 | 9.57 | 2.92 | 2.99 |
| Checkpoint path | Samples |
|---|---|
| ve/bedroom_ncsnpp_continuous | [bedroom samples] |
| ve/church_ncsnpp_continuous | [church samples] |
| ve/ffhq_1024_ncsnpp_continuous | [FFHQ 1024px samples] |
| ve/ffhq_256_ncsnpp_continuous | [FFHQ 256px samples] |
| ve/celebahq_256_ncsnpp_continuous | [CelebA-HQ 256px samples] |

Demonstrations and tutorials

| Link | Description |
|---|---|
| Open in Colab | Load our pretrained checkpoints and play with sampling, likelihood computation, and controllable synthesis (JAX + FLAX) |
| Open in Colab | Load our pretrained checkpoints and play with sampling, likelihood computation, and controllable synthesis (PyTorch) |
| Open in Colab | Tutorial of score-based generative models in JAX + FLAX |
| Open in Colab | Tutorial of score-based generative models in PyTorch |

Tips

  • When using the JAX codebase, you can jit multiple training steps together to improve training speed at the cost of more memory usage. This can be set via config.training.n_jitted_steps. For CIFAR-10, we recommend config.training.n_jitted_steps=5 when your GPU/TPU has sufficient memory; otherwise use config.training.n_jitted_steps=1. Our current implementation requires config.training.log_freq to be divisible by n_jitted_steps for logging and checkpointing to work normally.
  • The snr (signal-to-noise ratio) parameter of LangevinCorrector behaves somewhat like a temperature parameter. Larger snr typically results in smoother samples, while smaller snr gives more diverse but lower-quality samples. Typical values of snr are 0.05 to 0.2, and it requires tuning to strike the sweet spot.
  • For VE SDEs, we recommend choosing config.model.sigma_max to be the maximum pairwise distance between data samples in the training dataset (a sketch for estimating this is shown after this list).
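A minimal sketch for estimating that maximum pairwise distance (illustrative; the train_images variable and its preprocessing are assumptions, and for very large datasets you would typically estimate on a random subset):

    # Illustrative sketch: estimate the maximum pairwise Euclidean distance between
    # training images, to be used as config.model.sigma_max for VE SDEs.
    # `train_images` is assumed to be a float tensor of shape (N, C, H, W) in the
    # same value range the model is trained on.
    import torch

    def max_pairwise_distance(train_images, chunk=1024):
        x = train_images.flatten(start_dim=1)      # (N, D) flattened images
        max_dist = 0.0
        for i in range(0, x.shape[0], chunk):      # chunk rows to bound memory use
            d = torch.cdist(x[i:i + chunk], x)     # pairwise distances for this chunk
            max_dist = max(max_dist, d.max().item())
        return max_dist

    # config.model.sigma_max = max_pairwise_distance(train_images)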

References

If you find the code useful for your research, please consider citing:

@inproceedings{
  song2021scorebased,
  title={Score-Based Generative Modeling through Stochastic Differential Equations},
  author={Yang Song and Jascha Sohl-Dickstein and Diederik P Kingma and Abhishek Kumar and Stefano Ermon and Ben Poole},
  booktitle={International Conference on Learning Representations},
  year={2021},
  url={https://openreview.net/forum?id=PxTIG12RRHS}
}

This work is built upon some previous papers which might also interest you:

  • Song, Yang, and Stefano Ermon. "Generative Modeling by Estimating Gradients of the Data Distribution." Proceedings of the 33rd Annual Conference on Neural Information Processing Systems. 2019.
  • Song, Yang, and Stefano Ermon. "Improved Techniques for Training Score-Based Generative Models." Proceedings of the 34th Annual Conference on Neural Information Processing Systems. 2020.
  • Ho, Jonathan, Ajay Jain, and Pieter Abbeel. "Denoising Diffusion Probabilistic Models." Proceedings of the 34th Annual Conference on Neural Information Processing Systems. 2020.