Official code for Score-Based Generative Modeling through Stochastic Differential Equations

Overview

Score-Based Generative Modeling through Stochastic Differential Equations

This repo contains the official implementation for the paper Score-Based Generative Modeling through Stochastic Differential Equations by Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole.


We propose a unified framework that generalizes and improves previous work on score-based generative models through the lens of stochastic differential equations (SDEs). In particular, we can transform data to a simple noise distribution with a continuous-time stochastic process described by an SDE. This SDE can be reversed for sample generation if we know the score of the marginal distributions at each intermediate time step, which can be estimated with score matching. The basic idea is captured in the figure below:

[Figure: schematic]
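
As a rough illustration of the idea described above, here is a toy sketch that is not part of this codebase. It assumes one-dimensional Gaussian data, so the score of every marginal is known in closed form and no score network or score matching is needed. The forward SDE is taken to be dx = G dw with a constant diffusion G, and all names and constants (MU, S0, G, T, reverse_sde_sample) are hypothetical choices made for this example only.

```python
import math
import random

# Toy setup (illustrative assumptions, not values from the paper):
MU, S0 = 2.0, 0.5      # data distribution N(MU, S0^2)
G, T = 2.0, 1.0        # forward SDE dx = G dw on t in [0, T]


def marginal_var(t):
    """Variance of p_t when the forward SDE only adds noise: S0^2 + G^2 * t."""
    return S0 ** 2 + G ** 2 * t


def score(x, t):
    """Exact score of the Gaussian marginal p_t = N(MU, marginal_var(t))."""
    return -(x - MU) / marginal_var(t)


def reverse_sde_sample(rng, n_steps=500):
    """Integrate the reverse-time SDE from t=T down to t=0 with Euler-Maruyama."""
    dt = T / n_steps
    # Start from the terminal marginal p_T (the "simple noise distribution").
    x = rng.gauss(MU, math.sqrt(marginal_var(T)))
    for k in range(n_steps):
        t = T - k * dt
        # Reverse SDE with zero forward drift: dx = -G^2 * score dt + G dw_bar.
        # Time runs backward, so the score term enters with a plus sign.
        x = x + G ** 2 * score(x, t) * dt + G * math.sqrt(dt) * rng.gauss(0.0, 1.0)
    return x


rng = random.Random(0)
samples = [reverse_sde_sample(rng) for _ in range(2000)]
mean = sum(samples) / len(samples)
std = math.sqrt(sum((s - mean) ** 2 for s in samples) / len(samples))
print(f"sample mean {mean:.2f} (data {MU}), sample std {std:.2f} (data {S0})")
```

With the exact score, the reverse process should recover samples whose mean and standard deviation are close to those of the data distribution. In practice the score is replaced by a learned, time-dependent score model.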

Our work enables a better understanding of existing approaches, new sampling algorithms, exact likelihood computation, uniquely identifiable encoding, latent code manipulation, and brings new conditional generation abilities to the family of score-based generative models.

All combined, we achieved an FID of 2.20 and an Inception score of 9.89 for unconditional generation on CIFAR-10, as well as high-fidelity generation of 1024px CelebA-HQ images. In addition, we obtained a likelihood value of 2.99 bits/dim on uniformly dequantized CIFAR-10 images.
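
For readers unfamiliar with the bits/dim unit, the following minimal sketch shows the standard conversion from a per-image negative log-likelihood in nats. The function name and the offset_bits parameter are hypothetical and only illustrate the arithmetic; this is not necessarily how the codebase computes the reported value.

```python
import math


def bits_per_dim(nll_nats, num_dims, offset_bits=0.0):
    """Convert a per-image negative log-likelihood in nats to bits/dim.

    offset_bits accounts for rescaling the data before modeling, e.g. 8.0
    when 8-bit pixel values in [0, 256) were divided by 256.
    """
    return nll_nats / (num_dims * math.log(2.0)) + offset_bits


# A CIFAR-10 image has 32 * 32 * 3 = 3072 dimensions.
num_dims = 32 * 32 * 3
# Negative log-likelihood (nats per image) that corresponds to 2.99 bits/dim.
nll_nats = 2.99 * num_dims * math.log(2.0)
print(round(bits_per_dim(nll_nats, num_dims), 2))
```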

What does this code do?

Aside from the NCSN++ and DDPM++ models in our paper, this codebase also re-implements many previous score-based models all in one place, including NCSN from Generative Modeling by Estimating Gradients of the Data Distribution, NCSNv2 from Improved Techniques for Training Score-Based Generative Models, and DDPM from Denoising Diffusion Probabilistic Models.

It supports training new models and evaluating the sample quality and likelihoods of existing models. We carefully designed the code to be modular and easily extensible to new SDEs, predictors, or correctors.

How to run the code

Dependencies

Run the following to install a subset of the Python packages required by our code:

pip install -r requirements.txt

Usage

Train and evaluate our models through main.py.

main.py:
  --config: Training configuration.
    (default: 'None')
  --eval_folder: The folder name for storing evaluation results
    (default: 'eval')
  --mode: <train|eval>: Running mode: train or eval
  --workdir: Working directory
  • config is the path to the config file. Our prescribed config files are provided in configs/. They are formatted according to ml_collections and should be quite self-explanatory.

  • workdir is the path that stores all artifacts of one experiment, like checkpoints, samples, and evaluation results.

  • eval_folder is the name of a subfolder in workdir that stores all artifacts of the evaluation process, like meta checkpoints for pre-emption prevention, image samples, and numpy dumps of quantitative results.

  • mode is either "train" or "eval". When set to "train", it starts training a new model, or resumes training an existing model if its meta-checkpoints (for resuming after pre-emption in a cloud environment) exist in workdir. When set to "eval", it can perform any combination of the following:

    • Evaluate the loss function on the test / validation dataset.

    • Generate a fixed number of samples and compute its Inception score, FID, or KID.

    • Compute the log-likelihood on the training or test dataset.

    These functionalities can be configured through config files, or more conveniently, through the command-line support of the ml_collections package. For example, to generate samples and evaluate sample quality, supply the --config.eval.enable_sampling flag; to compute log-likelihoods, supply the --config.eval.enable_bpd flag, and specify --config.eval.dataset=train/test to indicate whether to compute the likelihoods on the training or test dataset.
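
The flags described above can be combined into a single evaluation command. The helper below is a small sketch that only assembles and prints such a command. The function build_eval_command, the config path configs/my_config.py, and the working directory runs/my_experiment are hypothetical placeholders; substitute a real config file from configs/ and your own workdir.

```python
import shlex


def build_eval_command(config, workdir, eval_folder="eval",
                       sampling=False, bpd=False, dataset=None):
    """Assemble a main.py evaluation command from the flags described above."""
    cmd = ["python", "main.py",
           f"--config={config}",
           f"--workdir={workdir}",
           "--mode=eval",
           f"--eval_folder={eval_folder}"]
    if sampling:
        # Generate samples and evaluate sample quality.
        cmd.append("--config.eval.enable_sampling")
    if bpd:
        # Compute log-likelihoods (bits/dim).
        cmd.append("--config.eval.enable_bpd")
    if dataset is not None:
        if dataset not in ("train", "test"):
            raise ValueError("dataset must be 'train' or 'test'")
        cmd.append(f"--config.eval.dataset={dataset}")
    return cmd


# Hypothetical paths, for illustration only.
cmd = build_eval_command("configs/my_config.py", "runs/my_experiment",
                         bpd=True, dataset="test")
print(shlex.join(cmd))
```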

How to extend the code

  • New SDEs: inherit from the sde_lib.SDE abstract class and implement all abstract methods. The discretize() method is optional; the default is Euler-Maruyama discretization. Existing sampling methods and likelihood computation will automatically work for the new SDE.
  • New predictors: inherit from the sampling.Predictor abstract class, implement the update_fn abstract method, and register its name with @register_predictor. The new predictor can be used directly in sampling.get_pc_sampler for Predictor-Corrector sampling, and in all other controllable generation methods in controllable_generation.py.
  • New correctors: inherit from the sampling.Corrector abstract class, implement the update_fn abstract method, and register its name with @register_corrector. The new corrector can be used directly in sampling.get_pc_sampler, and in all other controllable generation methods in controllable_generation.py.

Pretrained checkpoints

Link: https://drive.google.com/drive/folders/10pQygNzF7hOOLwP3q8GiNxSnFRpArUxQ?usp=sharing

You may find two checkpoints for some models. The first checkpoint (with the smaller number) is the one for which we reported FID scores in Table 3. The second checkpoint (with the larger number) is the one for which we reported likelihood values and FIDs of black-box ODE samplers in Table 2. The former corresponds to the smallest FID over the course of training (evaluated every 50k iterations). The latter is the last checkpoint saved during training.

Demonstrations and tutorials

  • Load our pretrained checkpoints and play with sampling, likelihood computation, and controllable synthesis

  • Tutorial of score-based generative models in JAX + FLAX

  • Tutorial of score-based generative models in PyTorch

References

If you find the code useful for your research, please consider citing

@inproceedings{
  song2021scorebased,
  title={Score-Based Generative Modeling through Stochastic Differential Equations},
  author={Yang Song and Jascha Sohl-Dickstein and Diederik P Kingma and Abhishek Kumar and Stefano Ermon and Ben Poole},
  booktitle={International Conference on Learning Representations},
  year={2021},
  url={https://openreview.net/forum?id=PxTIG12RRHS}
}

This work is built upon some previous papers which might also interest you:

  • Song, Yang, and Stefano Ermon. "Generative Modeling by Estimating Gradients of the Data Distribution." Proceedings of the 33rd Annual Conference on Neural Information Processing Systems. 2019.
  • Song, Yang, and Stefano Ermon. "Improved techniques for training score-based generative models." Proceedings of the 34th Annual Conference on Neural Information Processing Systems. 2020.
  • Ho, Jonathan, Ajay Jain, and Pieter Abbeel. "Denoising diffusion probabilistic models." Proceedings of the 34th Annual Conference on Neural Information Processing Systems. 2020.