Codebase for Diffusion Models Beat GANS on Image Synthesis.

Overview

guided-diffusion

This is the codebase for Diffusion Models Beat GANS on Image Synthesis.

This repository is based on openai/improved-diffusion, with modifications for classifier conditioning and architecture improvements.

Download pre-trained models

We have released checkpoints for the main models in the paper. Before using these models, please review the corresponding model card to understand the intended use and limitations of these models.

Here are the download links for each model checkpoint:

Sampling from pre-trained models

To sample from these models, you can use the classifier_sample.py, image_sample.py, and super_res_sample.py scripts. Here, we provide flags for sampling from all of these models. We assume that you have downloaded the relevant model checkpoints into a folder called models/.

For these examples, we will generate 100 samples with batch size 4. Feel free to change these values.

SAMPLE_FLAGS="--batch_size 4 --num_samples 100 --timestep_respacing 250"

Classifier guidance

Note for these sampling runs that you can set --classifier_scale 0 to sample from the base diffusion model. You may also use the image_sample.py script instead of classifier_sample.py in that case.

  • 64x64 model:
MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond True --diffusion_steps 1000 --dropout 0.1 --image_size 64 --learn_sigma True --noise_schedule cosine --num_channels 192 --num_head_channels 64 --num_res_blocks 3 --resblock_updown True --use_new_attention_order True --use_fp16 True --use_scale_shift_norm True"
python classifier_sample.py $MODEL_FLAGS --classifier_scale 1.0 --classifier_path models/64x64_classifier.pt --model_path models/64x64_diffusion.pt $SAMPLE_FLAGS
  • 128x128 model:
MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond True --diffusion_steps 1000 --image_size 128 --learn_sigma True --noise_schedule linear --num_channels 256 --num_heads 4 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
python classifier_sample.py $MODEL_FLAGS --classifier_scale 0.5 --classifier_path models/128x128_classifier.pt --model_path models/128x128_diffusion.pt $SAMPLE_FLAGS
  • 256x256 model:
MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond True --diffusion_steps 1000 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
python classifier_sample.py $MODEL_FLAGS --classifier_scale 1.0 --classifier_path models/256x256_classifier.pt --model_path models/256x256_diffusion.pt $SAMPLE_FLAGS
  • 256x256 model (unconditional):
MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
python classifier_sample.py $MODEL_FLAGS --classifier_scale 10.0 --classifier_path models/256x256_classifier.pt --model_path models/256x256_diffusion.pt $SAMPLE_FLAGS
  • 512x512 model:
MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond True --diffusion_steps 1000 --image_size 512 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 False --use_scale_shift_norm True"
python classifier_sample.py $MODEL_FLAGS --classifier_scale 4.0 --classifier_path models/512x512_classifier.pt --model_path models/512x512_diffusion.pt $SAMPLE_FLAGS

Upsampling

For these runs, we assume you have some base samples in a file 64_samples.npz or 128_samples.npz for the two respective models.

  • 64 -> 256:
MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond True --diffusion_steps 1000 --large_size 256  --small_size 64 --learn_sigma True --noise_schedule linear --num_channels 192 --num_heads 4 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
python super_res_sample.py $MODEL_FLAGS --model_path models/64_256_upsampler.pt --base_samples 64_samples.npz $SAMPLE_FLAGS
  • 128 -> 512:
MODEL_FLAGS="--attention_resolutions 32,16 --class_cond True --diffusion_steps 1000 --large_size 512 --small_size 128 --learn_sigma True --noise_schedule linear --num_channels 192 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
python super_res_sample.py $MODEL_FLAGS --model_path models/128_512_upsampler.pt $SAMPLE_FLAGS --base_samples 128_samples.npz

LSUN models

These models are class-unconditional and correspond to a single LSUN class. Here, we show how to sample from lsun_bedroom.pt, but the other two LSUN checkpoints should work as well:

MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --dropout 0.1 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
python image_sample.py $MODEL_FLAGS --model_path models/lsun_bedroom.pt $SAMPLE_FLAGS

You can sample from lsun_horse_nodropout.pt by changing the dropout flag:

MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --dropout 0.0 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
python image_sample.py $MODEL_FLAGS --model_path models/lsun_horse_nodropout.pt $SAMPLE_FLAGS

Note that for these models, the best samples result from using 1000 timesteps:

SAMPLE_FLAGS="--batch_size 4 --num_samples 100 --timestep_respacing 1000"

Results

This table summarizes our ImageNet results for pure guided diffusion models:

Dataset FID Precision Recall
ImageNet 64x64 2.07 0.74 0.63
ImageNet 128x128 2.97 0.78 0.59
ImageNet 256x256 4.59 0.82 0.52
ImageNet 512x512 7.72 0.87 0.42

This table shows the best results for high resolutions when using upsampling and guidance together:

Dataset FID Precision Recall
ImageNet 256x256 3.94 0.83 0.53
ImageNet 512x512 3.85 0.84 0.53

Finally, here are the unguided results on individual LSUN classes:

Dataset FID Precision Recall
LSUN Bedroom 1.90 0.66 0.51
LSUN Cat 5.57 0.63 0.52
LSUN Horse 2.57 0.71 0.55

Training models

Training diffusion models is described in the parent repository. Training a classifier is similar. We assume you have put training hyperparameters into a TRAIN_FLAGS variable, and classifier hyperparameters into a CLASSIFIER_FLAGS variable. Then you can run:

mpiexec -n N python scripts/classifier_train.py --data_dir path/to/imagenet $TRAIN_FLAGS $CLASSIFIER_FLAGS

Make sure to divide the batch size in TRAIN_FLAGS by the number of MPI processes you are using.

Here are flags for training the 128x128 classifier. You can modify these for training classifiers at other resolutions:

TRAIN_FLAGS="--iterations 300000 --anneal_lr True --batch_size 256 --lr 3e-4 --save_interval 10000 --weight_decay 0.05"
CLASSIFIER_FLAGS="--image_size 128 --classifier_attention_resolutions 32,16,8 --classifier_depth 2 --classifier_width 128 --classifier_pool attention --classifier_resblock_updown True --classifier_use_scale_shift_norm True"

For sampling from a 128x128 classifier-guided model, 25 step DDIM:

MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond True --image_size 128 --learn_sigma True --num_channels 256 --num_heads 4 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
CLASSIFIER_FLAGS="--image_size 128 --classifier_attention_resolutions 32,16,8 --classifier_depth 2 --classifier_width 128 --classifier_pool attention --classifier_resblock_updown True --classifier_use_scale_shift_norm True --classifier_scale 1.0 --classifier_use_fp16 True"
SAMPLE_FLAGS="--batch_size 4 --num_samples 50000 --timestep_respacing ddim25 --use_ddim True"
mpiexec -n N python scripts/classifier_sample.py \
    --model_path /path/to/model.pt \
    --classifier_path path/to/classifier.pt \
    $MODEL_FLAGS $CLASSIFIER_FLAGS $SAMPLE_FLAGS

To sample for 250 timesteps without DDIM, replace --timestep_respacing ddim25 to --timestep_respacing 250, and replace --use_ddim True with --use_ddim False.

Owner
Katherine Crowson
AI/generative artist.
Katherine Crowson
Generating Anime Images by Implementing Deep Convolutional Generative Adversarial Networks paper

AnimeGAN - Deep Convolutional Generative Adverserial Network PyTorch implementation of DCGAN introduced in the paper: Unsupervised Representation Lear

Rohit Kukreja 23 Jul 21, 2022
Robotics environments

Robotics environments Details and documentation on these robotics environments are available in OpenAI's blog post and the accompanying technical repo

Farama Foundation 121 Dec 28, 2022
Freecodecamp Scientific Computing with Python Certification; Solution for Challenge 2: Time Calculator

Assignment Write a function named add_time that takes in two required parameters and one optional parameter: a start time in the 12-hour clock format

Hellen Namulinda 0 Feb 26, 2022
Multi Agent Reinforcement Learning for ROS in 2D Simulation Environments

IROS21 information To test the code and reproduce the experiments, follow the installation steps in Installation.md. Afterwards, follow the steps in E

11 Oct 29, 2022
An implementation of quantum convolutional neural network with MindQuantum. Huawei, classifying MNIST dataset

关于实现的一点说明 山东大学 2020级 苏博南 www.subonan.com 文件说明 tools.py 这里面主要有两个函数: resize(a, lenb) 这其实是我找同学写的一个小算法hhh。给出一个$28\times 28$的方阵a,返回一个$lenb\times lenb$的方阵。因

ぼっけなす 2 Aug 29, 2022
Video-based open-world segmentation

UVO_Challenge Team Alpes_runner Solutions This is an official repo for our UVO Challenge solutions for Image/Video-based open-world segmentation. Our

Yuming Du 84 Dec 22, 2022
Self-Learning - Books Papers, Courses & more I have to learn soon

Self-Learning This repository is intended to be used for personal use, all rights reserved to respective owners, please cite original authors and ask

Achint Chaudhary 968 Jan 02, 2022
Implementation of "GNNAutoScale: Scalable and Expressive Graph Neural Networks via Historical Embeddings" in PyTorch

PyGAS: Auto-Scaling GNNs in PyG PyGAS is the practical realization of our G NN A uto S cale (GAS) framework, which scales arbitrary message-passing GN

Matthias Fey 139 Dec 25, 2022
PyTorch implementation of "Optimization Planning for 3D ConvNets"

Optimization-Planning-for-3D-ConvNets Code for the ICML 2021 paper: Optimization Planning for 3D ConvNets. Authors: Zhaofan Qiu, Ting Yao, Chong-Wah N

Zhaofan Qiu 2 Jan 12, 2022
Christmas face app for Decathlon xmas coding party!

Christmas Face Application Use this library to create the perfect picture for your christmas cards! Done by Hasib Zunair, Guillaume Brassard and Samue

Hasib Zunair 4 Dec 20, 2021
Tool for live presentations using manim

manim-presentation Tool for live presentations using manim Install pip install manim-presentation opencv-python Usage Use the class Slide as your sce

Federico Galatolo 146 Jan 06, 2023
Code for the paper "Regularizing Variational Autoencoder with Diversity and Uncertainty Awareness"

DU-VAE This is the pytorch implementation of the paper "Regularizing Variational Autoencoder with Diversity and Uncertainty Awareness" Acknowledgement

Dazhong Shen 4 Oct 19, 2022
code release for USENIX'22 paper `On the Security Risks of AutoML`

This project is a minimized runnable project cut from trojanzoo, which contains more datasets, models, attacks and defenses. This repo will not be mai

Ren Pang 5 Apr 19, 2022
Official pytorch implementation of Rainbow Memory (CVPR 2021)

Rainbow Memory: Continual Learning with a Memory of Diverse Samples

Clova AI Research 91 Dec 17, 2022
Deploy a ML inference service on a budget in less than 10 lines of code.

BudgetML is perfect for practitioners who would like to quickly deploy their models to an endpoint, but not waste a lot of time, money, and effort trying to figure out how to do this end-to-end.

1.3k Dec 25, 2022
A Python package for faster, safer, and simpler ML processes

Bender 🤖 A Python package for faster, safer, and simpler ML processes. Why use bender? Bender will make your machine learning processes, faster, safe

Otovo 6 Dec 13, 2022
Code for AA-RMVSNet: Adaptive Aggregation Recurrent Multi-view Stereo Network (ICCV 2021).

AA-RMVSNet Code for AA-RMVSNet: Adaptive Aggregation Recurrent Multi-view Stereo Network (ICCV 2021) in PyTorch. paper link: arXiv | CVF Change Log Ju

Qingtian Zhu 97 Dec 30, 2022
An implementation of Fastformer: Additive Attention Can Be All You Need in TensorFlow

Fast Transformer This repo implements Fastformer: Additive Attention Can Be All You Need by Wu et al. in TensorFlow. Fast Transformer is a Transformer

Rishit Dagli 139 Dec 28, 2022
Exploring the link between uncertainty estimates obtained via "exact" Bayesian inference and out-of-distribution (OOD) detection.

Uncertainty-based OOD detection Exploring the link between uncertainty estimates obtained by "exact" Bayesian inference and out-of-distribution (OOD)

Christian Henning 1 Nov 05, 2022
Improving Convolutional Networks via Attention Transfer (ICLR 2017)

Attention Transfer PyTorch code for "Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks via Attention Tran

Sergey Zagoruyko 1.4k Dec 23, 2022