Release for Improved Denoising Diffusion Probabilistic Models

Overview

improved-diffusion

This is the codebase for Improved Denoising Diffusion Probabilistic Models.

Usage

This section of the README walks through how to train and sample from a model.

Installation

Clone this repository and navigate to it in your terminal. Then run:

pip install -e .

This should install the improved_diffusion python package that the scripts depend on.

Preparing Data

The training code reads images from a directory of image files. In the datasets folder, we have provided instructions/scripts for preparing these directories for ImageNet, LSUN bedrooms, and CIFAR-10.

For creating your own dataset, simply dump all of your images into a directory with ".jpg", ".jpeg", or ".png" extensions. If you wish to train a class-conditional model, name the files like "mylabel1_XXX.jpg", "mylabel2_YYY.jpg", etc., so that the data loader knows that "mylabel1" and "mylabel2" are the labels. Subdirectories will automatically be enumerated as well, so the images can be organized into a recursive structure (although the directory names will be ignored, and the underscore prefixes are used as names).

The images will automatically be scaled and center-cropped by the data-loading pipeline. Simply pass --data_dir path/to/images to the training script, and it will take care of the rest.

Training

To train your model, you should first decide some hyperparameters. We will split up our hyperparameters into three groups: model architecture, diffusion process, and training flags. Here are some reasonable defaults for a baseline:

MODEL_FLAGS="--image_size 64 --num_channels 128 --num_res_blocks 3"
DIFFUSION_FLAGS="--diffusion_steps 4000 --noise_schedule linear"
TRAIN_FLAGS="--lr 1e-4 --batch_size 128"

Here are some changes we experiment with, and how to set them in the flags:

  • Learned sigmas: add --learn_sigma True to MODEL_FLAGS
  • Cosine schedule: change --noise_schedule linear to --noise_schedule cosine
  • Reweighted VLB: add --use_kl True to DIFFUSION_FLAGS and add --schedule_sampler loss-second-moment to TRAIN_FLAGS.
  • Class-conditional: add --class_cond True to MODEL_FLAGS.

Once you have setup your hyper-parameters, you can run an experiment like so:

python scripts/image_train.py --data_dir path/to/images $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS

You may also want to train in a distributed manner. In this case, run the same command with mpiexec:

mpiexec -n $NUM_GPUS python scripts/image_train.py --data_dir path/to/images $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS

When training in a distributed manner, you must manually divide the --batch_size argument by the number of ranks. In lieu of distributed training, you may use --microbatch 16 (or --microbatch 1 in extreme memory-limited cases) to reduce memory usage.

The logs and saved models will be written to a logging directory determined by the OPENAI_LOGDIR environment variable. If it is not set, then a temporary directory will be created in /tmp.

Sampling

The above training script saves checkpoints to .pt files in the logging directory. These checkpoints will have names like ema_0.9999_200000.pt and model200000.pt. You will likely want to sample from the EMA models, since those produce much better samples.

Once you have a path to your model, you can generate a large batch of samples like so:

python scripts/image_sample.py --model_path /path/to/model.pt $MODEL_FLAGS $DIFFUSION_FLAGS

Again, this will save results to a logging directory. Samples are saved as a large npz file, where arr_0 in the file is a large batch of samples.

Just like for training, you can run image_sample.py through MPI to use multiple GPUs and machines.

You can change the number of sampling steps using the --timestep_respacing argument. For example, --timestep_respacing 250 uses 250 steps to sample. Passing --timestep_respacing ddim250 is similar, but uses the uniform stride from the DDIM paper rather than our stride.

To sample using DDIM, pass --use_ddim True.

Owner
OpenAI
OpenAI
Write Streamlit apps using Notion! (Prototype)

Streamlit + Notion test app Write Streamlit apps using Notion! ☠️ IMPORTANT: This is just a little prototype I made to play with some ideas. Not meant

Thiago Teixeira 22 Sep 08, 2022
Quick script for automatically extracting syscall numbers for an OS

Syscalls-Extractor Quick script for automatically extracting syscall numbers for an OS $ python3 .\syscalls-extractor.py --help usage: syscalls-extrac

m0rv4i 54 Feb 10, 2022
KiCad bus length matching script.

KiBus length matching script This script implements way to monitor multiple nets, combined into a bus that needs to be length matched

Piotr Esden-Tempski 22 Mar 17, 2022
A Python tool to check ASS subtitles for common mistakes and errors.

A Python tool to check ASS subtitles for common mistakes and errors.

1 Dec 18, 2021
Домашние задания, выполненные на 3ем семестре РТУ МИРЭА, по дисциплине

ДЗ по курсу "Конфигурационное управление" в РТУ МИРЭА Описание В данном репозитории находятся домашние задания, выполненные на 3ем семестре РТУ МИРЭА,

Semyon Esaev 4 Dec 22, 2022
Mdisk - 🚧 On Construction 🚧

Mdisk Install For Package pip install mdisk pip install git+https://github.com/HeimanPictures/Mdisk.git Usage You can use this as python module or via

AkKiL 6 Aug 08, 2022
This is the repo for Uncertainty Quantification 360 Toolkit.

UQ360 The Uncertainty Quantification 360 (UQ360) toolkit is an open-source Python package that provides a diverse set of algorithms to quantify uncert

International Business Machines 207 Dec 30, 2022
Awesome Casino is simple offline casino made on python.

Awesome-Casino Awesome Casino is simple offline casino made on python. I found bug, what can i do? If you find any bug or want to suggest any idea, al

Herman 1 Feb 04, 2022
Adds a Bake node to Blender's shader node system

Bake to Target This Blender Addon adds a new shader node type capable of reducing the texture-bake step to a single button press. Please note that thi

Thomas 8 Oct 04, 2022
Spartan implementation of H.O.T.T.

Down The Path I was walking down the line, Trying to find some peace of mind. Then I saw you, You were takin' it slow, And walkin' it one step at a ti

Trebor Huang 25 Aug 05, 2022
Push a record and you will receive a email when that date

Push a record and you will receive a email when that date

5 Nov 28, 2022
kodi addon 115网盘

plugin.video.115 kodi addon 115网盘 插件,需要kodi 18以上版本,原码播放需配合 https://github.com/feelfar/115proxy-for-kodi 使用 安装 HEAD 由于release包尚未释出,可直接下载源代码zip包

109 Dec 29, 2022
My solution for a MARL problem on a Grid Environment with Q-tables.

To run the project, run: conda create --name env python=3.7 pip install -r requirements.txt python run.py To-do: Add direction to the state space Take

Merve Noyan 12 Dec 25, 2021
Chemical equation balancer

Chemical equation balancer Balance your chemical equations with ease! Installation $ git clone

Marijan Smetko 4 Nov 26, 2022
Lags valorant servers by rapidly picking up and throwing shorties.

Lags valorant servers by rapidly picking up and throwing shorties.

Eric Still 9 Dec 30, 2021
CupScript is a simple programing language made with python

CupScript CupScript is a simple programming language made with python It includes some basic functions, variables, loops, and some other built in func

FUSEN 23 Dec 29, 2022
Labspy06 With Python

Labspy06 Profil Nama : Nafal mumtaz fuadi Nim : 312110457 Kelas : T1.21.A.2 Latihan 1 Ubahlah kode dibawah ini menjadi fungsi menggunakan lambda impor

Mas Nafal 1 Dec 12, 2021
A simple way to read and write LAPS passwords from linux.

A simple way to read and write LAPS passwords from linux. This script is a python setter/getter for property ms-Mcs-AdmPwd used by LAPS inspired by @s

Podalirius 36 Dec 09, 2022
JHBuild is a tool designed to ease building collections of source packages, called “modules”.

JHBuild README JHBuild is a tool designed to ease building collections of source packages, called “modules”. JHBuild was originally written for buildi

GNOME Github Mirror 46 Nov 22, 2022
Script to calculate delegator epoch returns for all pillars

znn_delegator_calculator Script to calculate estimated delegator epoch returns for all Pillars, so you can delegate to the best one. You can find me o

2 Dec 03, 2021