Official code for "Maximum Likelihood Training of Score-Based Diffusion Models", NeurIPS 2021 (spotlight)

Overview

This repo contains the official implementation for the paper Maximum Likelihood Training of Score-Based Diffusion Models by Yang Song*, Conor Durkan*, Iain Murray, and Stefano Ermon, published in NeurIPS 2021 (spotlight).


We prove the connection between the Kullback–Leibler divergence and the weighted combination of score matching losses used for training score-based generative models. Our results can be viewed as a generalization of both the de Bruijn identity in information theory and the evidence lower bound in variational inference.
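For orientation, the connection can be written roughly as follows. This is a sketch in notation assumed from the paper, not defined in this README: p is the data distribution, p_t its perturbed marginal at time t, s_θ the score model, g the diffusion coefficient of the SDE, π the prior, and p_θ^SDE the distribution of the resulting generative model.

```latex
% Weighted score matching loss with weighting function \lambda(t)
\mathcal{J}_{\mathrm{SM}}\bigl(\theta; \lambda(\cdot)\bigr)
  = \frac{1}{2} \int_0^T \mathbb{E}_{p_t(\mathbf{x})}\!\left[
      \lambda(t)\,
      \bigl\lVert \nabla_{\mathbf{x}} \log p_t(\mathbf{x})
        - \mathbf{s}_\theta(\mathbf{x}, t) \bigr\rVert_2^2
    \right] \mathrm{d}t

% With the likelihood weighting \lambda(t) = g(t)^2, the loss bounds the KL divergence
D_{\mathrm{KL}}\bigl(p \,\|\, p_\theta^{\mathrm{SDE}}\bigr)
  \le \mathcal{J}_{\mathrm{SM}}\bigl(\theta; g(\cdot)^2\bigr)
    + D_{\mathrm{KL}}\bigl(p_T \,\|\, \pi\bigr)
```

Minimizing the weighted loss with this particular weighting therefore minimizes an upper bound on the negative log-likelihood, up to terms that do not depend on θ.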

Our theoretical results enable ScoreFlow, a continuous normalizing flow model trained with a variational objective, which is much more efficient than neural ODEs. We report state-of-the-art likelihoods on CIFAR-10 and ImageNet 32x32 among all flow models, achieving performance comparable to cutting-edge autoregressive models.

How to run the code

Dependencies

Run the following to install a subset of the Python packages required by our code:

pip install -r requirements.txt

Stats files for quantitative evaluation

We provide stats files for computing FID and Inception scores for CIFAR-10 and ImageNet 32x32. You can find cifar10_stats.npz and imagenet32_stats.npz under the directory assets/stats in our Google Drive. Download them and save them to assets/stats/ in the code repo.
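Before running an evaluation, it can help to confirm that both files are in place. The following is a minimal sketch, not part of the repo: the helper name missing_stats_files is hypothetical, and it only checks that the two files named above exist.

```python
from pathlib import Path

# File names come from the text above; the helper itself is hypothetical.
REQUIRED_STATS = ("cifar10_stats.npz", "imagenet32_stats.npz")


def missing_stats_files(stats_dir="assets/stats"):
    """Return the required stats files that are not present in stats_dir."""
    root = Path(stats_dir)
    return [name for name in REQUIRED_STATS if not (root / name).is_file()]
```

Calling missing_stats_files() from the repo root returns an empty list once both files have been downloaded.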

Usage

Train and evaluate our models through main.py. Here are some common options:

main.py:
  --config: Training configuration.
    (default: 'None')
  --eval_folder: The folder name for storing evaluation results
    (default: 'eval')
  --mode: <train|eval|train_deq>: Running mode: train or eval or training the Flow++ variational dequantization model
  --workdir: Working directory
  • config is the path to the config file. Our config files are provided in configs/. They are formatted according to ml_collections and should be quite self-explanatory.

    Naming conventions of config files: the name of a config file contains the following attributes:

    • dataset: Either cifar10 or imagenet32
    • model: Either ddpmpp_continuous or ddpmpp_deep_continuous
  • workdir is the path that stores all artifacts of one experiment, like checkpoints, samples, and evaluation results.

  • eval_folder is the name of a subfolder in workdir that stores all artifacts of the evaluation process, like meta checkpoints for supporting pre-emption recovery, image samples, and numpy dumps of quantitative results.

  • mode is either "train", "eval", or "train_deq". When set to "train", it starts the training of a new model, or resumes the training of an old model if its meta-checkpoints (for resuming after pre-emption in a cloud environment) exist in workdir/checkpoints-meta. When set to "eval", it can do the following:

    • Compute the log-likelihood on the training or test dataset.

    • Compute the lower bound of the log-likelihood on the training or test dataset.

    • Evaluate the loss function on the test / validation dataset.

    • Generate a fixed number of samples and compute their Inception score, FID, or KID. Prior to evaluation, stats files must have already been downloaded/computed and stored in assets/stats.

    When set to "train_deq", it trains a Flow++ variational dequantization model to bridge the gap between likelihoods on continuous and discrete images. This is recommended if you want to compete with generative models trained on discrete images, such as VAEs and autoregressive models. train_deq mode also supports pre-emption recovery.

These functionalities can be configured through config files, or more conveniently, through the command-line support of the ml_collections package.
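The options above can be assembled into a command line as in the following sketch. The helper build_command is hypothetical and not part of the repo, and the config path and working directory in the usage note are placeholders. It assumes that ml_collections accepts overrides of the form --config.<key>=<value>.

```python
def build_command(config, workdir, mode, eval_folder=None, overrides=None):
    """Assemble an argv list for main.py from the options described above."""
    if mode not in ("train", "eval", "train_deq"):
        raise ValueError(f"unknown mode: {mode!r}")
    argv = [
        "python", "main.py",
        f"--config={config}",
        f"--workdir={workdir}",
        f"--mode={mode}",
    ]
    if eval_folder is not None:
        argv.append(f"--eval_folder={eval_folder}")
    # Config overrides, e.g. {"eval.enable_bpd": True} (assumed override syntax).
    for key, value in (overrides or {}).items():
        argv.append(f"--config.{key}={value}")
    return argv
```

For example, build_command("configs/example.py", "runs/example", "train") returns a list that can be passed to subprocess.run from the repo root; both paths here are hypothetical.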

Configurations for training

To turn on likelihood weighting, set --config.training.likelihood_weighting. To additionally turn on importance sampling for variance reduction, also set the importance-sampling flag (assumed here to be --config.training.importance_weighting; confirm the exact name in the config files under configs/). To train a separate Flow++ variational dequantizer, first finish training a score-based model, then use --mode=train_deq.
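To give a feel for what likelihood weighting changes, the sketch below compares two weighting functions for a VP SDE. It is an illustration under stated assumptions, not the repo's loss code. It assumes a linear noise schedule beta(t) with illustrative endpoints, takes the likelihood weighting to be g(t)^2 = beta(t), and takes the alternative weighting to be the variance of the perturbation kernel.

```python
import numpy as np

# Assumed linear VP schedule; these endpoints are illustrative, not read from configs/.
BETA_MIN, BETA_MAX = 0.1, 20.0


def beta(t):
    """Noise schedule beta(t) = beta_min + t * (beta_max - beta_min) on [0, 1]."""
    return BETA_MIN + t * (BETA_MAX - BETA_MIN)


def likelihood_weight(t):
    """Likelihood weighting lambda(t) = g(t)^2, which equals beta(t) for the VP SDE."""
    return beta(t)


def variance_weight(t):
    """Perturbation-kernel variance, 1 - exp(-integral of beta from 0 to t)."""
    integral = BETA_MIN * t + 0.5 * t**2 * (BETA_MAX - BETA_MIN)
    return 1.0 - np.exp(-integral)


t = np.array([0.01, 0.5, 1.0])
print("likelihood weighting:", likelihood_weight(t))
print("variance weighting:  ", variance_weight(t))
```

Under these assumptions the variance weighting vanishes as t approaches 0, while the likelihood weighting stays at beta_min, so likelihood weighting places relatively more emphasis on small noise levels.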

Configurations for evaluation

To generate samples and evaluate sample quality, use the --config.eval.enable_sampling flag. To compute log-likelihoods, use the --config.eval.enable_bpd flag and set --config.eval.dataset=train or --config.eval.dataset=test to choose whether likelihoods are computed on the training or test dataset. Turn on --config.eval.bound to evaluate the variational bound for the log-likelihood. Enable --config.eval.dequantizer to use variational dequantization for likelihood computation. --config.eval.num_repeats sets the number of passes over the dataset (more passes can reduce the variance of the likelihood estimates; defaults to 5).
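Likelihoods of image models are commonly reported in bits per dimension (bpd). The sketch below shows the generic conversion from a log-likelihood in nats and the averaging over repeated passes. The function names are hypothetical, and the conversion ignores any data-scaling offsets the repo's evaluation code may apply.

```python
import math


def nats_to_bits_per_dim(log_likelihood_nats, num_dims):
    """Convert a per-example log-likelihood in nats to bits per dimension."""
    return -log_likelihood_nats / (num_dims * math.log(2.0))


def mean_bpd(bpd_values):
    """Average bits/dim over examples and repeated passes over the dataset."""
    values = list(bpd_values)
    return sum(values) / len(values)


NUM_DIMS = 3 * 32 * 32  # one 32x32 RGB image, as in CIFAR-10 and ImageNet 32x32
```

Averaging the estimates from several passes, as num_repeats does, reduces the variance of the reported number when the per-pass estimate is stochastic.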

Pretrained checkpoints

All checkpoints are provided in this Google Drive.

Folder structure:

  • assets: contains cifar10_stats.npz and imagenet32_stats.npz. Necessary for computing FID and Inception scores.
  • <cifar10|imagenet32>_(deep)_<vp|subvp>_(likelihood)_(iw)_(flip). Here the part enclosed in () is optional. deep in the name specifies whether the score model is a deeper architecture (ddpmpp_deep_continuous). likelihood specifies whether the model was trained with likelihood weighting. iw specifies whether the model was trained with importance sampling for variance reduction. flip shows whether the model was trained with horizontal flip for data augmentation. Each folder has the following two subfolders:
    • checkpoints: contains the last checkpoint for the score-based model.
    • flowpp_dequantizer/checkpoints: contains the last checkpoint for the Flow++ variational dequantization model.
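The naming scheme above can be decoded mechanically, as in the following sketch. The helper parse_checkpoint_folder is hypothetical and not part of the repo. It assumes that the parts are joined by single underscores in exactly the order shown, and that vp/subvp names the SDE type.

```python
import re

# Pattern for <cifar10|imagenet32>_(deep)_<vp|subvp>_(likelihood)_(iw)_(flip).
_FOLDER_PATTERN = re.compile(
    r"^(?P<dataset>cifar10|imagenet32)"
    r"(?P<deep>_deep)?"
    r"_(?P<sde>vp|subvp)"
    r"(?P<likelihood>_likelihood)?"
    r"(?P<iw>_iw)?"
    r"(?P<flip>_flip)?$"
)


def parse_checkpoint_folder(name):
    """Split a checkpoint folder name into its attributes."""
    match = _FOLDER_PATTERN.match(name)
    if match is None:
        raise ValueError(f"unrecognized checkpoint folder name: {name!r}")
    return {
        "dataset": match["dataset"],
        "deep": match["deep"] is not None,
        "sde": match["sde"],  # assumed meaning of vp/subvp
        "likelihood_weighting": match["likelihood"] is not None,
        "importance_sampling": match["iw"] is not None,
        "flip": match["flip"] is not None,
    }
```

For instance, parse_checkpoint_folder("cifar10_deep_subvp_likelihood_iw_flip") reports the deep architecture with every optional setting enabled.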

References

If you find the code useful for your research, please consider citing

@inproceedings{song2021maximum,
  title={Maximum Likelihood Training of Score-Based Diffusion Models},
  author={Song, Yang and Durkan, Conor and Murray, Iain and Ermon, Stefano},
  booktitle={Thirty-Fifth Conference on Neural Information Processing Systems},
  year={2021}
}

This work is built upon some previous papers which might also interest you:

  • Yang Song and Stefano Ermon. "Generative Modeling by Estimating Gradients of the Data Distribution." Proceedings of the 33rd Annual Conference on Neural Information Processing Systems, 2019.
  • Yang Song and Stefano Ermon. "Improved Techniques for Training Score-Based Generative Models." Proceedings of the 34th Annual Conference on Neural Information Processing Systems, 2020.
  • Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole. "Score-Based Generative Modeling through Stochastic Differential Equations." Proceedings of the 9th International Conference on Learning Representations, 2021.