A minimal yet resourceful implementation of diffusion models (along with pretrained models + synthetic images for nine datasets)

Overview

Minimal implementation of diffusion models

A minimal implementation of diffusion models with the goal to democratize the use of synthetic data from these models.

Check out the experimental results section for quantitative numbers on quality of synthetic data and FAQs for a broader discussion. We experiments with nine commonly used datasets, and released all assets, including models and synthetic data for each of them.

Requirements: pip install scipy opencv-python. We assume torch and torchvision are already installed.

Structure

main.py  - Train or sample from a diffusion model.
unets.py - UNet based network architecture for diffusion model.
data.py  - Common datasets and their metadata.
──  scripts
     └── train.sh  - Training scripts for all datasets.
     └── sample.sh - Sampling scripts for all datasets.

Training

Use the following command to train the diffusion model on four gpus.

CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 main.py \
  --arch UNet --dataset cifar10 --class-cond --epochs 500

We provide the exact script used for training in ./scripts/train.sh.

Sampling

We reuse main.py for sampling but with the --sampling-only only flag. Use the following command to sample 50K images from a pretrained diffusion model.

CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 main.py \
  --arch UNet --dataset cifar10 --class-cond --sampling-only --sampling-steps 250 \
  --num-sampled-images 50000 --pretrained-ckpt path_to_pretrained_model

We provide the exact script used for sampling in ./scripts/sample.sh.

How useful is synthetic data from diffusion models? 🤔

Takeaway: Across all datasets, training only on synthetic data suffice to achieve a competitive classification score on real data.

Goal: Our goal is to not only measure photo-realism of synthetic images but also measure how well synthetic images cover the data distribution, i.e., how diverse is synthetic data. Note that a generative model, commonly GANs, can generate high-quality images, but still fail to generate diverse images.

Choice of datasets: We use nine commonly used datasets in image recognition. The goal was to multiple datasets was to capture enough diversity in terms of the number of samples, the number of classes, and coarse vs fine-grained classification. In addition, by using a common setup across datasets, we can test the success of diffusion models without any assumptions about the dataset.

Diffusion model: For each dataset, we train a class-conditional diffusion model. We choose a modest size network and train it for a limited number of hours on a 4xA4000 cluster, as highlighted by the training time in the table below. Next, we sample 50,000 synthetic images from the diffusion model.

Metric to measure synthetic data quality: We train a ResNet50 classifier on only real images and another one on only synthetic images and measure their accuracy on the validation set of real images. This metric is also referred to as classification accuracy score and it provides us a way to measure both quality and diversity of synthetic data in a unified manner across datasets.

Released assets for each dataset: Pre-trained Diffusion models, 50,000 synthetic images for each dataset, and downstream clasifiers trained with real-only or synthetic-only dataset.

Table 1: Training images and classes refer to the number of training images and the number of classes in the dataset. Training time refers to the time taken to train the diffusion model. Real only is the test set accuracy of ResNet-50 model trained on only real training images. Synthetic accuracy is the test accuracy of the ResNet-50 model trained on only 50K synthetic images.

Dataset Training images Classes Training time (hours) Real only Synthetic only
MNIST 60,000 10 2.1 99.6 99.0
MNIST-M 60,000 10 5.3 99.3 97.3
CIFAR-10 50,000 10 10.7 93.8 87.3
Skin Cancer* 33126 2 19.1 69.7 64.1
AFHQ 14630 3 8.6 97.9 98.7
CelebA 109036 4 12.8 90.1 88.9
Standford Cars 8144 196 7.4 33.7 76.6
Oxford Flowers 2040 102 6.0 29.3 76.3
Traffic signs 39252 43 8.3 96.6 96.1

* Due to heavy class imbalance, we use AUROC to measure classification performance.

Note: Except for CIFAR10, MNIST, MNIST-M, and GTSRB, we use 64x64 image resolution for all datasets. The key reason to use a lower resolution was to reduce the computational resources needed to train the diffusion model.

Discussion: Across most datasets training only on synthetic data achieves competitive performance with training on real data. It shows that the synthetic data 1) has high-quality images, otherwise the model wouldn't have learned much from it 2) high coverage of distribution, otherwise, the model trained on synthetic data won't do well on the whole test set. Even more, the synthetic dataset has a unique advantage: we can easily generate a very large amount of it. This difference is clearly visible for the low-data regime (flowers and cars dataset), where training on synthetic data (50K images) achieves much better performance than training on real data, which has less than 10K images. A more principled investigation of sample complexity, i.e., performance vs number-of-synthetic-images is available in one of my previous papers (fig. 9).

FAQs

Q. Why use diffusion models?
A. This question is super broad and has multiple answers. 1) They are super easy to train. Unlike GANs, there are no training instabilities in the optimization process. 2) The mode coverage of the diffusion models is excellent where at the same time the generated images are quite photorealistic. 3) The training pipeline is also consistent across datasets, i.e., no assumption about the data. For all datasets above, the only parameter we changed was the amount of training time.

Q. Is synthetic data from diffusion models much different from other generative models, in particular GANs?
A. As mentioned in the previous answer, synthetic data from diffusion models have much higher coverage than GANs, while having a similar image quality. Check out the this previous paper by Prafulla Dhariwal and Alex Nichol where they provide extensive results supporting this claim. In the regime of robust training, you can find a more quantitive comparison of diffusion models with multiple GANs in one of my previous papers.

Q. Why classification accuracy on some datasets is so low (e.g., flowers), even when training with real data?
A. Due to many reasons, current classification numbers aren't meant to be competitive with state-of-the-art. 1) We don't tune any hyperparameters across datasets. For each dataset, we train a ResNet50 model with 0.1 learning rate, 1e-4 weight decay, 0.9 momentum, and cosine learning rate decay. 2) Instead of full resolution (commonly 224x224), we use low-resolution images (64x64), which makes classification harder.

Q. Using only synthetic data, how to further improve the test accuracy on real data?
A. Diffusion models benefit tremendously from scaling of the training setup. One can do so by increasing the network width (base_width) and training the network for more epochs (2-4x).

References

This implementation was originally motivated by the original implmentation of diffusion models by Jonathan Ho. I followed the recent PyTorch implementation by OpenAI for common design choices in diffusion models.

The experiments to test out the potential of synthetic data from diffusion models are inspired by one of my previous work. We found that using synthetic data from the diffusion model alone surpasses benefits from multiple algorithmic innovations in robust training, which is one of the simple yet extremely hard problems to solve for neural networks. The next step is to repeat the Table-1 experiments, but this time with robust training.

Visualizing real and synthetic images

For each data, we plot real images on the left and synthetic images on the right. Each row corresponds to a unique class while classes for real and synthetic data are identical.

Light     Dark

MNIST

Light     Dark

MNIST-M

Light     Dark

CIFAR-10

Light     Dark

GTSRB

Light     Dark

Celeb-A

Light     Dark

AFHQ

Light     Dark

Cars

Light     Dark

Flowers

Light     Dark

Melanoma (Skin cancer)

Light     Dark

Note: Real images for each dataset follow the same license as their respective dataset.

Owner
Vikash Sehwag
PhD candidate at Princeton University. Interested in problems at the intersection of Security, Privacy, and Machine leanring.
Vikash Sehwag
Cl datasets - PyTorch image dataloaders and utility functions to load datasets for supervised continual learning

Continual learning datasets Introduction This repository contains PyTorch image

berjaoui 5 Aug 28, 2022
Code for the KDD 2021 paper 'Filtration Curves for Graph Representation'

Filtration Curves for Graph Representation This repository provides the code from the KDD'21 paper Filtration Curves for Graph Representation. Depende

Machine Learning and Computational Biology Lab 16 Oct 16, 2022
NER for Indian languages

CL-NERIL: A Cross-Lingual Model for NER in Indian Languages Code for the paper - https://arxiv.org/abs/2111.11815 Setup Setup a virtual environment Th

Akshara P 0 Nov 24, 2021
Code for the paper "Adversarially Regularized Autoencoders (ICML 2018)" by Zhao, Kim, Zhang, Rush and LeCun

ARAE Code for the paper "Adversarially Regularized Autoencoders (ICML 2018)" by Zhao, Kim, Zhang, Rush and LeCun https://arxiv.org/abs/1706.04223 Disc

Junbo (Jake) Zhao 399 Jan 02, 2023
[CVPR2021] De-rendering the World's Revolutionary Artefacts

De-rendering the World's Revolutionary Artefacts Project Page | Video | Paper In CVPR 2021 Shangzhe Wu1,4, Ameesh Makadia4, Jiajun Wu2, Noah Snavely4,

49 Nov 06, 2022
K Closest Points and Maximum Clique Pruning for Efficient and Effective 3D Laser Scan Matching (To appear in RA-L 2022)

KCP The official implementation of KCP: k Closest Points and Maximum Clique Pruning for Efficient and Effective 3D Laser Scan Matching, accepted for p

Yu-Kai Lin 109 Dec 14, 2022
Blender Add-On for slicing meshes with planes

MeshSlicer Blender Add-On for slicing meshes with multiple overlapping planes at once. This is a simple Blender addon to slice a silmple mesh with mul

52 Dec 12, 2022
Power Core Simulator!

Power Core Simulator Power Core Simulator is a simulator based off the Roblox game "Pinewood Builders Computer Core". In this simulator, you can choos

BananaJeans 1 Nov 13, 2021
Contains code for the paper "Vision Transformers are Robust Learners".

Vision Transformers are Robust Learners This repository contains the code for the paper Vision Transformers are Robust Learners by Sayak Paul* and Pin

Sayak Paul 103 Jan 05, 2023
Official repository of "BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment"

BasicVSR_PlusPlus (CVPR 2022) [Paper] [Project Page] [Code] This is the official repository for BasicVSR++. Please feel free to raise issue related to

Kelvin C.K. Chan 227 Jan 01, 2023
This is an official implementation of the CVPR2022 paper "Blind2Unblind: Self-Supervised Image Denoising with Visible Blind Spots".

Blind2Unblind: Self-Supervised Image Denoising with Visible Blind Spots Blind2Unblind Citing Blind2Unblind @inproceedings{wang2022blind2unblind, tit

demonsjin 58 Dec 06, 2022
A tiny, pedagogical neural network library with a pytorch-like API.

candl A tiny, pedagogical implementation of a neural network library with a pytorch-like API. The primary use of this library is for education. Use th

Sri Pranav 3 May 23, 2022
Creating a Linear Program Solver by Implementing the Simplex Method in Python with NumPy

Creating a Linear Program Solver by Implementing the Simplex Method in Python with NumPy Simplex Algorithm is a popular algorithm for linear programmi

Reda BELHAJ 2 Oct 12, 2022
Mahadi-Now - This Is Pakistani Just Now Login Tools

PAKISTANI JUST NOW LOGIN TOOLS Install apt update apt upgrade apt install python

MAHADI HASAN AFRIDI 19 Apr 06, 2022
A curated list of awesome resources related to Semantic Search🔎 and Semantic Similarity tasks.

A curated list of awesome resources related to Semantic Search🔎 and Semantic Similarity tasks.

224 Jan 04, 2023
Multi-resolution SeqMatch based long-term Place Recognition

MRS-SLAM for long-term place recognition In this work, we imply an multi-resolution sambling based visual place recognition method. This work is based

METASLAM 6 Dec 06, 2022
This project is a re-implementation of MASTER: Multi-Aspect Non-local Network for Scene Text Recognition by MMOCR

This project is a re-implementation of MASTER: Multi-Aspect Non-local Network for Scene Text Recognition by MMOCR,which is an open-source toolbox based on PyTorch. The overall architecture will be sh

Jianquan Ye 82 Nov 17, 2022
Python tools for 3D face: 3DMM, Mesh processing(transform, camera, light, render), 3D face representations.

face3d: Python tools for processing 3D face Introduction This project implements some basic functions related to 3D faces. You can use this to process

Yao Feng 2.3k Dec 30, 2022
Marvis is Mastouri's Jarvis version of the AI-powered Python personal assistant.

Marvis v1.0 Marvis is Mastouri's Jarvis version of the AI-powered Python personal assistant. About M.A.R.V.I.S. J.A.R.V.I.S. is a fictional character

Reda Mastouri 1 Dec 29, 2021
HarDNeXt: Official HarDNeXt repository

HarDNeXt-Pytorch HarDNeXt: A Stage Receptive Field and Connectivity Aware Convolution Neural Network HarDNeXt-MSEG for Medical Image Segmentation in 0

5 May 26, 2022