PyTorch reimplementation of Diffusion Models

Overview

PyTorch pretrained Diffusion Models

A PyTorch reimplementation of Denoising Diffusion Probabilistic Models with checkpoints converted from the author's TensorFlow implementation.

Quickstart

Running

pip install -e git+https://github.com/pesser/pytorch_diffusion.git#egg=pytorch_diffusion
pytorch_diffusion_demo

will start a Streamlit demo. It is recommended to run the demo with a GPU available.

Usage

Diffusion models with pretrained weights for cifar10, lsun_bedroom, lsun_cat or lsun_church can be loaded as follows:

from pytorch_diffusion import Diffusion

diffusion = Diffusion.from_pretrained("lsun_church")
samples = diffusion.denoise(4)
diffusion.save(samples, "lsun_church_sample_{:02}.png")
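
The filename passed to save contains a Python format field. The following is a minimal sketch of how such a pattern expands, assuming that save fills the field with the running sample index (an assumption about the library's behaviour, not something stated here):

```python
# Assumption: the save method substitutes each sample's index into the pattern.
pattern = "lsun_church_sample_{:02}.png"

# {:02} zero-pads the index to at least two digits.
filenames = [pattern.format(i) for i in range(4)]
print(filenames)
```

Under that assumption, four samples would end up in four separate, numbered PNG files.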

Prefix the name with ema_ to load the exponential-moving-average (EMA) weights, which achieve a markedly lower FID (see Results). The U-Net model used for denoising is available via diffusion.model and can also be instantiated on its own:

from pytorch_diffusion import Model

model = Model(resolution=32,
              in_channels=3,
              out_ch=3,
              ch=128,
              ch_mult=(1,2,2,2),
              num_res_blocks=2,
              attn_resolutions=(16,),
              dropout=0.1)

This configuration example corresponds to the model used on CIFAR-10.
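
To make the configuration easier to read, here is a small sketch of how these arguments are usually interpreted in a DDPM-style U-Net. It assumes the common convention that the spatial resolution halves between consecutive levels, that each level's width is ch times its ch_mult entry, and that self-attention is used at the resolutions listed in attn_resolutions. The helper unet_levels is hypothetical and not part of pytorch_diffusion.

```python
def unet_levels(resolution, ch, ch_mult, attn_resolutions):
    """Return (resolution, channels, has_attention) for each U-Net level.

    Hypothetical helper; assumes resolution halves after every level
    except the last, as in the usual DDPM U-Net layout.
    """
    levels = []
    curr_res = resolution
    for i, mult in enumerate(ch_mult):
        levels.append((curr_res, ch * mult, curr_res in attn_resolutions))
        if i != len(ch_mult) - 1:
            curr_res //= 2  # downsample before the next level
    return levels


cifar10_levels = unet_levels(32, 128, (1, 2, 2, 2), (16,))
for res, channels, attn in cifar10_levels:
    print(f"{res:>2}x{res:<2} channels={channels} attention={attn}")
```

Under these assumptions the CIFAR-10 configuration has four levels, with attention only at the 16x16 level.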

Producing samples

If you installed directly from GitHub, the cloned repository is in <venv path>/src/pytorch_diffusion for virtual environments and in <cwd>/src/pytorch_diffusion for global installs. From there, run

python pytorch_diffusion/diffusion.py <name> <bs> <nb>

where <name> is one of cifar10, lsun_bedroom, lsun_cat, lsun_church, or one of these names prefixed with ema_; <bs> is the batch size and <nb> is the number of batches. This produces samples from the PyTorch models and saves them to results/<name>/.
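
The arguments can be sanity-checked before starting a long sampling run. The sketch below builds the command line and reports the total number of samples (bs times nb). The helper sampling_command is hypothetical, and the underscore spelling lsun_bedroom is an assumption chosen to match the checkpoint directory names used elsewhere in this document.

```python
# Assumed set of valid names: four datasets, each optionally prefixed with ema_.
DATASETS = ("cifar10", "lsun_bedroom", "lsun_cat", "lsun_church")
VALID_NAMES = DATASETS + tuple("ema_" + d for d in DATASETS)


def sampling_command(name, bs, nb):
    """Return (argv, total_samples) for one sampling run (hypothetical helper)."""
    if name not in VALID_NAMES:
        raise ValueError(f"unknown model name: {name!r}")
    if bs <= 0 or nb <= 0:
        raise ValueError("batch size and number of batches must be positive")
    argv = ["python", "pytorch_diffusion/diffusion.py", name, str(bs), str(nb)]
    return argv, bs * nb


argv, total = sampling_command("ema_cifar10", 500, 100)
print(" ".join(argv), "->", total, "samples")
```

The resulting list can be passed to subprocess.run from the repository root.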

Results

Evaluating 50k samples with torch-fidelity gives

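| Dataset            | EMA | Framework  | Model          | FID      |
|--------------------|-----|------------|----------------|----------|
| CIFAR10 Train      | no  | PyTorch    | cifar10        | 12.13775 |
| CIFAR10 Train      | no  | TensorFlow | tf_cifar10     | 12.30003 |
| CIFAR10 Train      | yes | PyTorch    | ema_cifar10    | 3.21213  |
| CIFAR10 Train      | yes | TensorFlow | tf_ema_cifar10 | 3.245872 |
| CIFAR10 Validation | no  | PyTorch    | cifar10        | 14.30163 |
| CIFAR10 Validation | no  | TensorFlow | tf_cifar10     | 14.44705 |
| CIFAR10 Validation | yes | PyTorch    | ema_cifar10    | 5.274105 |
| CIFAR10 Validation | yes | TensorFlow | tf_ema_cifar10 | 5.325035 |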
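
As a quick check that the converted models track the originals, the gap between the TensorFlow and PyTorch FID can be computed for each row pair. The numbers below are copied from the results above; the variable names are only illustrative.

```python
# (split, uses_ema, pytorch_fid, tensorflow_fid), copied from the results above.
rows = [
    ("train", False, 12.13775, 12.30003),
    ("train", True, 3.21213, 3.245872),
    ("validation", False, 14.30163, 14.44705),
    ("validation", True, 5.274105, 5.325035),
]

# Positive gap means the PyTorch model scored a lower (better) FID.
gaps = [(split, ema, tf_fid - pt_fid) for split, ema, pt_fid, tf_fid in rows]

for split, ema, gap in gaps:
    print(f"{split:<10} ema={ema!s:<5} TF - PyTorch = {gap:.5f}")
```

In every reported pair the two frameworks differ by less than 0.2 FID.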

To reproduce, generate 50k samples (batch size 500, 100 batches) from the converted PyTorch models provided in this repo with

python pytorch_diffusion/diffusion.py <Model> 500 100

where <Model> is cifar10 or ema_cifar10, and with

python -c "import convert as m; m.sample_tf(500, 100, which=['cifar10', 'ema_cifar10'])"

for the original TensorFlow models.

Running conversions

The converted PyTorch checkpoints are provided for download. To convert them yourself, follow the steps below.

Setup

This section assumes your working directory is the root of this repository. Download the pretrained TensorFlow checkpoints. The extracted directory should follow the original structure:

diffusion_models_release/
  diffusion_cifar10_model/
    model.ckpt-790000.data-00000-of-00001
    model.ckpt-790000.index
    model.ckpt-790000.meta
  diffusion_lsun_bedroom_model/
    ...
  ...
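
Before running the conversion, the downloaded layout can be verified with a short script. This is a sketch only: it assumes each model directory should contain at least one .index, one .meta and one .data-* file, as in the tree above, and the helper incomplete_checkpoints is hypothetical.

```python
from pathlib import Path


def incomplete_checkpoints(root):
    """Return names of model directories missing an .index, .meta or .data-* file."""
    incomplete = []
    for model_dir in sorted(Path(root).iterdir()):
        if not model_dir.is_dir():
            continue
        names = [p.name for p in model_dir.iterdir()]
        has_index = any(n.endswith(".index") for n in names)
        has_meta = any(n.endswith(".meta") for n in names)
        has_data = any(".data-" in n for n in names)
        if not (has_index and has_meta and has_data):
            incomplete.append(model_dir.name)
    return incomplete


release_root = Path("diffusion_models_release")
if release_root.is_dir():
    print("incomplete:", incomplete_checkpoints(release_root))
```

An empty list means every model directory has the three expected kinds of files.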

Set the environment variable TFROOT to the directory where you want to store the author's repository, e.g.

export TFROOT=".."

Clone the diffusion repository,

git clone https://github.com/hojonathanho/diffusion.git ${TFROOT}/diffusion

and install its required dependencies (pip install -r ${TFROOT}/diffusion/requirements.txt). Then add the following to your PYTHONPATH:

export PYTHONPATH=".:./scripts:${TFROOT}/diffusion:${TFROOT}/diffusion/scripts:${PYTHONPATH}"

Testing operations

To test the PyTorch implementations of the required operations against their TensorFlow counterparts under random initialization and random inputs, run

python -c "import convert as m; m.test_ops()"

Converting checkpoints

To load the pretrained TensorFlow models, copy their weights into the PyTorch models, check that the outputs match on random inputs, and save the corresponding PyTorch checkpoints, run

python -c "import convert as m; m.transplant_cifar10()"
python -c "import convert as m; m.transplant_cifar10(ema=True)"
python -c "import convert as m; m.transplant_lsun_bedroom()"
python -c "import convert as m; m.transplant_lsun_bedroom(ema=True)"
python -c "import convert as m; m.transplant_lsun_cat()"
python -c "import convert as m; m.transplant_lsun_cat(ema=True)"
python -c "import convert as m; m.transplant_lsun_church()"
python -c "import convert as m; m.transplant_lsun_church(ema=True)"

PyTorch checkpoints will be saved in

diffusion_models_converted/
  diffusion_cifar10_model/
    model-790000.ckpt
  ema_diffusion_cifar10_model/
    model-790000.ckpt
  diffusion_lsun_bedroom_model/
    model-2388000.ckpt
  ema_diffusion_lsun_bedroom_model/
    model-2388000.ckpt
  diffusion_lsun_cat_model/
    model-1761000.ckpt
  ema_diffusion_lsun_cat_model/
    model-1761000.ckpt
  diffusion_lsun_church_model/
    model-4432000.ckpt
  ema_diffusion_lsun_church_model/
    model-4432000.ckpt
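
The naming scheme above is regular enough to derive a checkpoint path from a model name. The sketch below does this with a hypothetical helper, converted_checkpoint, which is not part of the repository. The step counts are the ones listed in the tree.

```python
from pathlib import Path

# Training step of each converted checkpoint, as listed in the tree above.
STEPS = {
    "cifar10": 790000,
    "lsun_bedroom": 2388000,
    "lsun_cat": 1761000,
    "lsun_church": 4432000,
}


def converted_checkpoint(name, root="diffusion_models_converted"):
    """Map a model name such as 'ema_lsun_cat' to its converted checkpoint path."""
    ema = name.startswith("ema_")
    dataset = name[len("ema_"):] if ema else name
    if dataset not in STEPS:
        raise KeyError(f"unknown model name: {name!r}")
    prefix = "ema_" if ema else ""
    return Path(root) / f"{prefix}diffusion_{dataset}_model" / f"model-{STEPS[dataset]}.ckpt"


print(converted_checkpoint("ema_lsun_cat").as_posix())
```

A check such as converted_checkpoint(name).is_file() can then confirm that a conversion run produced its output.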

Sample TensorFlow models

To produce N samples from each of the pretrained TensorFlow models, run

python -c "import convert as m; m.sample_tf(N)"

To restrict sampling to specific models, pass a list of model names as the keyword argument which. Samples will be saved in results/.

Owner
Patrick Esser