A minimalist implementation of score-based diffusion models

Overview

sdeflow-light

This is a minimalist codebase for training score-based diffusion models (supporting MNIST and CIFAR-10), used in the following paper:

"A Variational Perspective on Diffusion-Based Generative Models and Score Matching" by Chin-Wei Huang, Jae Hyun Lim and Aaron Courville [arXiv]

Also see the concurrent work by Yang Song, Conor Durkan et al., who used the same idea to obtain state-of-the-art likelihood estimates.

Experiments on Swissroll

Here's a Colab notebook which contains an example for training a model on the Swissroll dataset.

In this notebook, you'll see how to train the model using the score matching loss, how to evaluate the ELBO of the plug-in reverse SDE, and how to sample from it. It also includes a snippet to sample from a family of plug-in reverse SDEs (parameterized by λ) described in Appendix C of the paper.

Below are the trajectories for λ=0 (the reverse SDE used by Song et al.) and λ=1 (the equivalent ODE) when we plug in the learned score / drift function. This corresponds to Figure 5 of the paper.

[Figure: plug-in reverse SDE trajectories for λ=0 and λ=1]
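The exact parameterization of the λ-family from Appendix C is not reproduced in this README, so the following is a hedged NumPy sketch that assumes one family matching the two endpoints just described: dX = (-f + (1 - λ/2) g² · score) dt + sqrt(1 - λ) g dB. At λ=0 the drift is g² · score - f = ga - f (the reverse SDE); at λ=1 the noise vanishes and the drift is ½ g² · score - f (the ODE). The function name and its arguments are hypothetical, and a toy model (Gaussian data N(0, 1) with f = 0 and g = 1, whose score at inference time s is exactly -x/(1+s)) stands in for the learned U-net.

```python
import numpy as np


def sample_plugin_reverse_sde(score, f, g, x_init, T=1.0, lam=0.0, num_steps=1000, rng=None):
    """Euler-Maruyama sampler for an ASSUMED lambda-family of plug-in reverse SDEs.

    Generative time t runs from 0 to T; the matching inference time is s = T - t.
    Drift:     -f(x, s) + (1 - lam / 2) * g(s)**2 * score(x, s)
    Diffusion: sqrt(1 - lam) * g(s)
    lam = 0 gives the reverse SDE (drift g*a - f with a = g*score); lam = 1 gives the ODE.
    """
    rng = np.random.default_rng() if rng is None else rng
    x = np.array(x_init, dtype=float)
    dt = T / num_steps
    noise_scale = np.sqrt(max(1.0 - lam, 0.0))
    for k in range(num_steps):
        s = T - k * dt  # inference time corresponding to generative time t = k * dt
        g_s = g(s)
        drift = -f(x, s) + (1.0 - lam / 2.0) * g_s**2 * score(x, s)
        x = x + drift * dt + noise_scale * g_s * np.sqrt(dt) * rng.standard_normal(x.shape)
    return x


# Toy check: data N(0, 1), f = 0, g = 1, so the inference marginal at time s is N(0, 1 + s).
rng = np.random.default_rng(0)
T = 1.0
score = lambda x, s: -x / (1.0 + s)
f = lambda x, s: np.zeros_like(x)
g = lambda s: 1.0
x_T = rng.normal(0.0, np.sqrt(1.0 + T), size=20_000)  # samples from the prior N(0, 1 + T)
samples = {lam: sample_plugin_reverse_sde(score, f, g, x_T, T=T, lam=lam, rng=rng) for lam in (0.0, 1.0)}
for lam, x0 in samples.items():
    print(lam, x0.var())  # both variances should be close to the data variance 1
```

Euler-Maruyama is used only for simplicity; with λ=1 it reduces to the explicit Euler method for the ODE, and either could be swapped for a higher-order solver.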

Experiments on MNIST and CIFAR-10

This repository contains one main training loop (train_img.py). The model is trained to minimize the denoising score matching loss by calling the .dsm(x) loss function, and is evaluated with the following ELBO by calling .elbo_random_t_slice(x):

[Figure: the ELBO evaluated by .elbo_random_t_slice(x) (score-elbo)]

where the divergence (sum of the diagonal entries of the Jacobian) is estimated using the Hutchinson trace estimator.
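The following hedged NumPy sketch illustrates a random-time-slice ELBO estimate; it is not the repository's .elbo_random_t_slice(x). For each sample it draws one time s ~ U(0, T), estimates the time integral as T times the integrand at that slice, and estimates the divergence with Hutchinson's trace estimator using Rademacher probes. The integrand ½||a||² + ∇·(ga - f) is our reading of the paper's ELBO under this README's parameterization. The function names, the central finite-difference Jacobian-vector product (a PyTorch implementation would typically use autograd instead), and the Gaussian toy are assumptions. In the toy the plugged-in drift is exact, so the estimate should approach log N(x; 0, I).

```python
import numpy as np


def hutchinson_divergence(fn, y, rng, h=1e-4):
    """Estimate div fn(y) = tr(J_fn(y)) for each row of y (shape (n, d)).

    One Rademacher probe eps per row: E[eps^T J eps] = tr(J). The product J @ eps is
    approximated by a central finite difference, which is exact when fn is linear in y.
    """
    eps = rng.choice([-1.0, 1.0], size=y.shape)
    jvp = (fn(y + h * eps) - fn(y - h * eps)) / (2.0 * h)
    return np.sum(eps * jvp, axis=1)


def elbo_random_slice_estimate(x, a, f, g, log_p0, sample_y, T, rng, n=100_000):
    """Hypothetical Monte Carlo ELBO estimate for one data point x (shape (d,))."""
    xs = np.repeat(x[None, :], n, axis=0)
    s = rng.uniform(0.0, T, size=(n, 1))             # one random time slice per sample
    y_s = sample_y(xs, s, rng)                       # Y_s ~ q(Y_s | Y_0 = x)
    mu = lambda y: g(s) * a(y, s) - f(y, s)          # generative drift mu = g a - f at this slice
    integrand = 0.5 * np.sum(a(y_s, s) ** 2, axis=1) + hutchinson_divergence(mu, y_s, rng)
    y_T = sample_y(xs, np.full((n, 1), T), rng)      # Y_T ~ q(Y_T | Y_0 = x)
    return np.mean(log_p0(y_T) - T * integrand)      # T * integrand estimates the time integral


# Toy: data N(0, I), f = 0, g = 1, so q(Y_s | x) = N(x, s I) and the exact drift is a = -y / (1 + s).
rng = np.random.default_rng(0)
d, T = 4, 1.0
x = np.array([0.5, -1.0, 0.0, 2.0])
a = lambda y, s: -y / (1.0 + s)
f = lambda y, s: np.zeros_like(y)
g = lambda s: np.ones_like(s)
log_p0 = lambda y: -0.5 * d * np.log(2 * np.pi * (1 + T)) - np.sum(y**2, axis=1) / (2 * (1 + T))
sample_y = lambda x0, s, rng: x0 + np.sqrt(s) * rng.standard_normal(x0.shape)
elbo = elbo_random_slice_estimate(x, a, f, g, log_p0, sample_y, T, rng)
exact = -0.5 * d * np.log(2 * np.pi) - 0.5 * np.sum(x**2)  # log N(x; 0, I)
print(elbo, exact)  # should agree up to Monte Carlo error
```

A single probe per sample keeps the cost to two extra evaluations of the drift; averaging over the batch (and over many time slices) is what drives the variance down.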

It's a minimalist codebase in the sense that we do not use a fancy optimizer (we only use Adam with the default setup) or learning rate scheduling. We use the modified U-Net architecture from Denoising Diffusion Probabilistic Models by Jonathan Ho et al.

A key difference from Song et al. is that instead of parameterizing the score function s, here we parameterize the drift term a (the two are related by a=gs, where g is the diffusion coefficient). That is, the U-net outputs a.
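Because a = gs, a denoising score matching target for the score translates directly into one for the drift. The hedged sketch below assumes the inference SDE has a Gaussian transition q(Y_s | x) = N(m(s) x, σ(s)² I), which holds when f is linear in Y. Writing y = m(s) x + σ(s) ε, the conditional score is -ε/σ(s), and weighting the squared score error by σ(s)² gives the residual σ(s) a/g + ε. The function name, the σ² weighting, and the toy are assumptions, not the repository's .dsm(x).

```python
import numpy as np


def dsm_loss(a, x, mean_coef, std, g, T, rng):
    """Per-example DSM loss for a drift model a(y, s) = g(s) * score(y, s) (assumed form).

    Assumes q(Y_s | x) = N(mean_coef(s) * x, std(s)**2 I). With y = mean_coef * x + std * eps,
    the conditional score is -eps / std; weighting the squared error by std**2 gives
    || std * a / g + eps ||**2.
    """
    s = rng.uniform(0.0, T, size=(x.shape[0], 1))   # uniform time sampling
    eps = rng.standard_normal(x.shape)
    y = mean_coef(s) * x + std(s) * eps
    resid = std(s) * a(y, s) / g(s) + eps
    return 0.5 * np.sum(resid**2, axis=1)


# Toy: data N(0, I), f = 0, g = 1, so mean_coef = 1, std = sqrt(s), and the optimal drift is -y / (1 + s).
rng = np.random.default_rng(0)
x = rng.standard_normal((200_000, 2))
ones = lambda s: np.ones_like(s)
exact_drift = lambda y, s: -y / (1.0 + s)
loss_exact = dsm_loss(exact_drift, x, ones, np.sqrt, ones, 1.0, rng).mean()
loss_zero = dsm_loss(lambda y, s: np.zeros_like(y), x, ones, np.sqrt, ones, 1.0, rng).mean()
print(loss_exact, loss_zero)  # the exact drift gives the smaller loss (about ln 2 vs 1.0 here)
```

Dividing by g converts the drift back to a score, so the same loss works whichever of a or s the network outputs.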

Parameterization: our original generative and inference SDEs are

  • dX = mu dt + sigma dBt
  • dY = (-mu + sigma*a) ds + sigma dBs

We reparameterize them as

  • dX = (ga - f) dt + g dBt
  • dY = f ds + g dBs

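by letting mu = ga - f and sigma = g. Substituting into the inference drift gives -mu + sigma*a = -(ga - f) + ga = f, so the inference SDE no longer depends on a. Since f and g are fixed, the only remaining degree of freedom is a. Alternatively, one can parameterize s (e.g. using the U-net) and simply let a=gs.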
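Under this parameterization, our reading of the paper's continuous-time ELBO can be written in terms of a, f and g alone. This is a sketch in the README's notation, not a transcription of the figure: Y follows the inference SDE dY = f ds + g dBs started at x, and the generative drift mu = ga - f enters only through its divergence.

```latex
\log p(x) \;\ge\; \mathbb{E}\!\left[\, \log p_0(Y_T) \;-\; \int_0^T \Big( \tfrac{1}{2}\,\lVert a(Y_s, s) \rVert_2^2 \;+\; \nabla \cdot \big( g(s)\, a(Y_s, s) - f(Y_s, s) \big) \Big)\, ds \;\middle|\; Y_0 = x \right]
```

Assuming g is a scalar function of time only, the divergence term splits into g(s) ∇·a - ∇·f, and for linear f the second part is a constant that needs no trace estimate.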

How it works

Here's an example command line for running an experiment:

python train_img.py --dataroot=[DATAROOT] --saveroot=[SAVEROOT] --expname=[EXPNAME] \
    --dataset=cifar --print_every=2000 --sample_every=2000 --checkpoint_every=2000 --num_steps=1000 \
    --batch_size=128 --lr=0.0001 --num_iterations=100000 --real=True --debias=False

Setting --debias=False samples the time variable uniformly, whereas --debias=True uses the non-uniform sampling strategy described in the paper to debias the gradient estimate. Below are the test-set bits-per-dim and the corresponding standard error recorded during training (orange: --debias=True; blue: --debias=False).
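The paper's exact proposal distribution for the time variable is not reproduced in this README, and the sketch below is not the repository's --debias implementation. As a generic, hedged illustration of non-uniform time sampling: an integral over t in [0, T] can be estimated without bias by drawing t from any density q(t) > 0 and weighting each sample by 1/q(t) (importance sampling). The hypothetical proposal q(t) ∝ 1/(t + δ) below concentrates samples near t = 0; the function names and δ are assumptions.

```python
import numpy as np


def sample_t_inverse_cdf(n, T, delta, rng):
    """Draw t in [0, T] from q(t) = 1 / ((t + delta) * Z), Z = log((T + delta) / delta)."""
    u = rng.uniform(0.0, 1.0, size=n)
    return delta * ((T + delta) / delta) ** u - delta   # inverse of the CDF log((t + delta) / delta) / Z


def q_density(t, T, delta):
    return 1.0 / ((t + delta) * np.log((T + delta) / delta))


def integral_estimate(h, n, T, rng, delta=None):
    """Estimate int_0^T h(t) dt: uniform sampling if delta is None, else importance sampling."""
    if delta is None:
        t = rng.uniform(0.0, T, size=n)
        weights = np.full(n, T)                          # 1 / q(t) for the uniform density 1/T
    else:
        t = sample_t_inverse_cdf(n, T, delta, rng)
        weights = 1.0 / q_density(t, T, delta)           # importance weights 1 / q(t)
    values = h(t) * weights
    return values.mean(), values.std() / np.sqrt(n)      # estimate and its standard error


rng = np.random.default_rng(0)
T, delta = 1.0, 0.01
h = lambda t: 1.0 / (t + delta)                          # integrand sharply peaked near t = 0
exact = np.log((T + delta) / delta)
print(exact)
print(integral_estimate(h, 100_000, T, rng))                # uniform t: noticeably noisy
print(integral_estimate(h, 100_000, T, rng, delta=delta))   # importance-sampled t
```

Here h happens to be proportional to q, so the importance-weighted values are constant and the estimate has zero variance; for a real ELBO integrand the gain depends on how well q tracks its shape.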

[Figure: test-set bits-per-dim and standard error during training]

Here are some samples (debiased on the right):

[Figure: samples with --debias=False and --debias=True]

It takes about 14 hours to finish 100k iterations on a V100 GPU.

Owner
Chin-Wei Huang