How to Train a GAN? Tips and tricks to make GANs work

(this list is no longer maintained, and I am not sure how relevant it is in 2020)

While research in Generative Adversarial Networks (GANs) continues to improve the fundamental stability of these models, we use a bunch of tricks to train them and make them stable day to day.

Here is a summary of some of the tricks.

The authors of this document are listed at the end.

If you find a trick that is particularly useful in practice, please open a Pull Request to add it to the document. If we find it to be reasonable and verified, we will merge it in.

1: Normalize the inputs

  • Normalize the images to the range [-1, 1]
  • Use Tanh as the last layer of the generator, so its output lies in the same range
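A minimal sketch of this scaling, assuming 8-bit pixel values in [0, 255] held in a flat Python list (the helper names are illustrative). The inverse mapping is what you would apply to a Tanh output before saving it as an image.

```python
import math

def to_tanh_range(pixels):
    """Map 8-bit pixel values in [0, 255] to floats in [-1, 1]."""
    return [p / 127.5 - 1.0 for p in pixels]

def from_tanh_range(values):
    """Map values in [-1, 1] (e.g. a Tanh output) back to integer pixels in [0, 255]."""
    return [min(255, max(0, round((v + 1.0) * 127.5))) for v in values]

pixels = [0, 64, 128, 255]
scaled = to_tanh_range(pixels)                    # 0 maps to -1.0 and 255 maps to 1.0
fake = [math.tanh(x) for x in (-4.0, 0.0, 4.0)]   # Tanh always lands inside (-1, 1)
print(scaled, from_tanh_range(fake))
```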

2: A modified loss function

In GAN papers, the loss function used to optimize G is min log(1 - D(G(z))), but in practice people use max log D(G(z)) instead

  • because the first formulation has vanishing gradients early in training
  • Goodfellow et al. (2014)

In practice, this works well:

  • Flip the labels when training the generator: real = fake, fake = real
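To see why the max log D form helps, here is a small numeric sketch assuming D outputs sigmoid(logit) for a generated sample. It compares the gradient of each generator loss with respect to that logit, and checks that flipping a fake sample's label to "real" in a standard binary cross-entropy gives exactly -log D(G(z)). Plain floats stand in for framework tensors.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def bce(p, target):
    """Binary cross-entropy of a probability p against a 0/1 target."""
    return -(target * math.log(p) + (1 - target) * math.log(1.0 - p))

def grad_saturating(logit):
    # G minimizes log(1 - D) with D = sigmoid(logit): d/dlogit = -D, which vanishes as D -> 0
    return -sigmoid(logit)

def grad_non_saturating(logit):
    # G minimizes -log D (i.e. maximizes log D): d/dlogit = D - 1, which stays near -1 as D -> 0
    return sigmoid(logit) - 1.0

# Early in training D rejects fakes confidently: logit -6 gives D(G(z)) of about 0.0025
logit = -6.0
print(grad_saturating(logit), grad_non_saturating(logit))  # roughly -0.0025 and -0.9975

# Labelling a fake sample as "real" (target 1) turns its BCE into -log D(G(z))
d_fake = sigmoid(logit)
print(bce(d_fake, 1), -math.log(d_fake))
```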

3: Use a spherical Z

  • Don't sample from a uniform distribution (a cube)
  • Sample from a Gaussian distribution (a sphere)
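One way to see why the Gaussian is the "spherical" choice: its density depends only on the length of z, and in high dimensions samples concentrate near a shell of radius about sqrt(dim), whereas uniform samples fill a cube with corners. The sketch below assumes a 100-dimensional latent space, a common but arbitrary choice.

```python
import math
import random

def sample_gaussian_z(dim, rng):
    """Latent vector with i.i.d. N(0, 1) entries (the recommended prior)."""
    return [rng.gauss(0.0, 1.0) for _ in range(dim)]

def sample_uniform_z(dim, rng):
    """Latent vector uniform in the cube [-1, 1]^dim (the prior to avoid)."""
    return [rng.uniform(-1.0, 1.0) for _ in range(dim)]

def norm(v):
    return math.sqrt(sum(x * x for x in v))

rng = random.Random(0)
dim = 100
gauss_norms = [norm(sample_gaussian_z(dim, rng)) for _ in range(1000)]
# Gaussian samples cluster near a sphere of radius sqrt(dim) = 10
mean_norm = sum(gauss_norms) / len(gauss_norms)
print(mean_norm, min(gauss_norms), max(gauss_norms))
```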

4: BatchNorm

  • Construct different mini-batches for real and fake, i.e. each mini-batch should contain only real images or only generated images.
  • When batchnorm is not an option, use instance normalization (for each sample, subtract the mean and divide by the standard deviation).
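A sketch of both bullets, assuming a sample is stored as a list of channels, each a flat list of its H*W values. instance_norm follows the usual definition (statistics computed per channel of a single sample, plus a small eps for numerical safety); split_batches is a hypothetical helper showing all-real / all-fake batching.

```python
import math

def instance_norm(sample, eps=1e-5):
    """Normalize one sample: each channel is shifted to zero mean and scaled to
    unit variance using only that sample's own statistics (no batch statistics)."""
    out = []
    for channel in sample:
        mean = sum(channel) / len(channel)
        var = sum((x - mean) ** 2 for x in channel) / len(channel)
        out.append([(x - mean) / math.sqrt(var + eps) for x in channel])
    return out

def split_batches(real, fake, batch_size):
    """Hypothetical helper: mini-batches that are all-real or all-fake, never mixed."""
    real_batches = [real[i:i + batch_size] for i in range(0, len(real), batch_size)]
    fake_batches = [fake[i:i + batch_size] for i in range(0, len(fake), batch_size)]
    return real_batches, fake_batches
```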

5: Avoid Sparse Gradients: ReLU, MaxPool

  • the stability of the GAN game suffers if you have sparse gradients
  • LeakyReLU = good (in both G and D)
  • For Downsampling, use: Average Pooling, Conv2d + stride
  • For Upsampling, use: PixelShuffle, ConvTranspose2d + stride
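The gradient sparsity is easy to see in one dimension. This sketch, with illustrative helper names, compares the local derivatives of ReLU and LeakyReLU, and of max pooling and average pooling over a single window. The LeakyReLU slope of 0.2 is an assumption (it is the value used in the DCGAN paper); any small positive slope keeps the gradient nonzero.

```python
def relu_grad(x):
    # Zero for every negative input: those units pass no gradient back
    return 1.0 if x > 0 else 0.0

def leaky_relu_grad(x, slope=0.2):
    # Negative inputs still pass a small gradient
    return 1.0 if x > 0 else slope

def max_pool_grad(window):
    """Gradient of max(window) w.r.t. each input: only the argmax receives gradient."""
    i = window.index(max(window))
    return [1.0 if j == i else 0.0 for j in range(len(window))]

def avg_pool_grad(window):
    """Gradient of mean(window) w.r.t. each input: every input receives 1/n."""
    return [1.0 / len(window)] * len(window)

xs = [-2.0, -0.5, 0.3, 1.5]
print([relu_grad(x) for x in xs], [leaky_relu_grad(x) for x in xs])
```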

6: Use Soft and Noisy Labels

  • Label smoothing, i.e. if you have two target labels, Real=1 and Fake=0, then for each incoming sample, if it is real, replace the label with a random number between 0.7 and 1.2, and if it is a fake sample, replace it with a random number between 0.0 and 0.3 (for example).
    • Salimans et al. 2016
  • make the labels noisy for the discriminator: occasionally flip the labels when training the discriminator
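A sketch of soft, noisy discriminator targets using the example ranges in the label-smoothing bullet. The flip probability of 0.05 is an illustrative assumption, not a value given in this document; rng is a seeded random.Random so results are reproducible.

```python
import random

def soft_label(is_real, rng):
    """Label smoothing with the example ranges: real -> U(0.7, 1.2), fake -> U(0.0, 0.3)."""
    return rng.uniform(0.7, 1.2) if is_real else rng.uniform(0.0, 0.3)

def discriminator_targets(is_real_flags, rng, flip_prob=0.05):
    """Soft targets for D, where each sample's real/fake identity is occasionally flipped.
    flip_prob=0.05 is an assumption for illustration."""
    targets = []
    for is_real in is_real_flags:
        if rng.random() < flip_prob:
            is_real = not is_real  # noisy label: treat a real sample as fake or vice versa
        targets.append(soft_label(is_real, rng))
    return targets

rng = random.Random(0)
print(discriminator_targets([True, True, False, False], rng))
```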

7: DCGAN / Hybrid Models

  • Use DCGAN when you can. It works!
  • if you can't use DCGANs and no model is stable, use a hybrid model: KL + GAN or VAE + GAN

8: Use stability tricks from RL

  • Experience Replay
    • Keep a replay buffer of past generations and occasionally show them to the discriminator
    • Keep checkpoints from the past of G and D and occasionally swap them out for a few iterations
  • All stability tricks that work for deep deterministic policy gradients
  • See Pfau & Vinyals (2016)
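A minimal sketch of the replay-buffer idea, assuming generated samples are arbitrary Python objects. New fakes fill a bounded pool; once it is full, each new fake is, with some probability, swapped for a randomly chosen older one before the batch goes to D. The class name, capacity, and replay_prob are hypothetical choices.

```python
import random

class ReplayBuffer:
    """Bounded pool of past generator outputs (capacity and replay_prob are illustrative)."""

    def __init__(self, capacity=1000, replay_prob=0.5, seed=None):
        self.capacity = capacity
        self.replay_prob = replay_prob
        self.samples = []
        self.rng = random.Random(seed)

    def push_and_sample(self, batch):
        """Store the new fakes and return the batch D should see, in which
        some entries may have been replaced by older generations."""
        out = []
        for x in batch:
            if len(self.samples) < self.capacity:
                self.samples.append(x)       # pool not full yet: keep and pass through
                out.append(x)
            elif self.rng.random() < self.replay_prob:
                i = self.rng.randrange(self.capacity)
                out.append(self.samples[i])  # show an old fake to D
                self.samples[i] = x          # and store the new fake in its place
            else:
                out.append(x)
        return out
```

Returning the old sample while storing the new one keeps the pool at a fixed size, so memory stays bounded however long training runs.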

9: Use the Adam Optimizer

  • optim.Adam rules!
    • See Radford et al. 2015
  • Use SGD for the discriminator and Adam for the generator
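For reference, here is a from-scratch sketch of the Adam update for a list of scalar parameters, next to a plain SGD step; in practice you would use your framework's optimizer (e.g. optim.Adam). The defaults lr=2e-4 and beta1=0.5 are the settings reported in the DCGAN paper (Radford et al. 2015); beta2 and eps are the usual defaults.

```python
import math

class Adam:
    """Minimal Adam over a list of scalar parameters (a sketch, not a framework API)."""

    def __init__(self, n_params, lr=2e-4, beta1=0.5, beta2=0.999, eps=1e-8):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.m = [0.0] * n_params  # running mean of gradients
        self.v = [0.0] * n_params  # running mean of squared gradients
        self.t = 0

    def step(self, params, grads):
        self.t += 1
        new_params = []
        for i, (p, g) in enumerate(zip(params, grads)):
            self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * g
            self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * g * g
            m_hat = self.m[i] / (1 - self.beta1 ** self.t)  # bias correction
            v_hat = self.v[i] / (1 - self.beta2 ** self.t)
            new_params.append(p - self.lr * m_hat / (math.sqrt(v_hat) + self.eps))
        return new_params

def sgd_step(params, grads, lr=0.01):
    """Plain SGD, e.g. for the discriminator."""
    return [p - lr * g for p, g in zip(params, grads)]
```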

10: Track failures early

  • D loss goes to 0: failure mode
  • check norms of gradients: if they are over 100, things are going wrong
  • when things are working, D loss has low variance and goes down over time, rather than having huge variance and spiking
  • if the loss of the generator steadily decreases, then it's fooling D with garbage (says Martin)
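A sketch of these checks as a function you might call every few iterations, assuming you log scalar D and G losses and can flatten the gradients into a list of floats. The threshold of 100 comes from the list; the D-loss floor and the window length are illustrative assumptions.

```python
import math

def grad_norm(grads):
    """L2 norm over all (flattened) gradient entries."""
    return math.sqrt(sum(g * g for g in grads))

def check_training(d_losses, g_losses, grads, d_loss_floor=1e-3, grad_limit=100.0, window=50):
    """Return warnings for the failure signs listed above.
    d_loss_floor and window are assumptions; grad_limit=100 follows the list."""
    warnings = []
    if d_losses and d_losses[-1] < d_loss_floor:
        warnings.append("D loss near 0: likely failure")
    if grad_norm(grads) > grad_limit:
        warnings.append("gradient norm over 100")
    recent = g_losses[-window:]
    if len(recent) == window and all(b < a for a, b in zip(recent, recent[1:])):
        warnings.append("G loss steadily decreasing: check samples for garbage")
    return warnings
```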

11: Don't balance loss via statistics (unless you have a good reason to)

  • Don't try to find a (number of G / number of D) schedule to uncollapse training
  • It's hard and we've all tried it.
  • If you do try it, have a principled approach to it, rather than intuition

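For example:

while lossD > A:
  lossD = train_D()  # train_D: placeholder for one D update; returns the new D loss
while lossG > B:
  lossG = train_G()  # train_G: placeholder for one G update; returns the new G loss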

12: If you have labels, use them

  • if you have labels available, train the discriminator to also classify the samples: auxiliary classifier GANs
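A sketch of the per-sample discriminator loss for an auxiliary-classifier setup, assuming D has two heads: a probability that the input is real and a vector of class logits. The loss adds the usual real/fake binary cross-entropy to a class cross-entropy; generated samples are scored against the label G was conditioned on. Function names are hypothetical.

```python
import math

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def aux_discriminator_loss(p_real, class_logits, label, is_real):
    """Source loss (is it real?) plus class loss (which class is it?), for one sample."""
    source_loss = -math.log(p_real) if is_real else -math.log(1.0 - p_real)
    class_loss = -math.log(softmax(class_logits)[label])
    return source_loss + class_loss
```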

13: Add noise to inputs, decay over time
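A minimal sketch, assuming inputs are flat lists of floats and that the same noise schedule is applied to both the real and the generated inputs D sees. The linear decay and the starting standard deviation of 0.1 are illustrative assumptions.

```python
import random

def noise_std(step, total_steps, start_std=0.1):
    """Linearly decay the input-noise standard deviation to 0 over training
    (start_std and the linear shape are assumptions)."""
    return start_std * max(0.0, 1.0 - step / total_steps)

def add_input_noise(x, step, total_steps, rng):
    """Add Gaussian noise to every value of an input fed to D."""
    std = noise_std(step, total_steps)
    return [v + rng.gauss(0.0, std) for v in x]
```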

14: [notsure] Train discriminator more (sometimes)

  • especially when you have noise
  • hard to find a schedule of number of D iterations vs G iterations

15: [notsure] Batch Discrimination

  • Mixed results
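For reference, a sketch of minibatch discrimination as constructed in Salimans et al. (2016): each sample's feature vector is projected through a tensor T into B rows, and for each row D gets a similarity score summed over the batch. Here T is random and the loops use plain lists; in practice T is a learned parameter and the computation is vectorized.

```python
import math
import random

def minibatch_features(features, T):
    """features: n vectors f(x_i) of length A. T: nested lists of shape A x B x C.
    Returns B scores per sample: o(x_i)_b = sum_j exp(-||M_i,b - M_j,b||_1)."""
    A, B, C = len(T), len(T[0]), len(T[0][0])
    # M_i = f(x_i) contracted with T: a B x C matrix for each sample
    Ms = [[[sum(f[a] * T[a][b][c] for a in range(A)) for c in range(C)] for b in range(B)]
          for f in features]
    scores = []
    for Mi in Ms:
        # exp of the negative L1 distance between row b of M_i and row b of M_j, summed over j
        scores.append([sum(math.exp(-sum(abs(Mi[b][c] - Mj[b][c]) for c in range(C))) for Mj in Ms)
                       for b in range(B)])
    return scores

rng = random.Random(0)
A, B, C = 2, 3, 2
T = [[[rng.gauss(0.0, 1.0) for _ in range(C)] for _ in range(B)] for _ in range(A)]
collapsed = minibatch_features([[1.0, 2.0]] * 4, T)  # every score equals the batch size, 4
diverse = minibatch_features([[1.0, 2.0], [-3.0, 0.5], [0.0, 4.0], [2.5, -1.0]], T)
```

When every sample in a batch is identical (a collapsed generator), each score equals the batch size, which is the signal D can use to penalize mode collapse.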

16: Discrete variables in Conditional GANs

  • Use an Embedding layer
  • Add as additional channels to images
  • Keep the embedding dimensionality low and upsample it to match the image's spatial size (height × width)
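A sketch of the three bullets, assuming images are lists of channels, each an H x W nested list, and that the embedding table is a hypothetical stand-in for a learned Embedding layer (random values here). Each embedding value is upsampled to a constant H x W plane by nearest-neighbour replication and appended as an extra channel.

```python
import random

def make_embedding_table(num_classes, embed_dim, rng):
    """Hypothetical stand-in for a learned Embedding layer: one small vector per class."""
    return [[rng.gauss(0.0, 1.0) for _ in range(embed_dim)] for _ in range(num_classes)]

def add_condition_channels(image, label, table):
    """Look up the label's embedding, replicate each value over an H x W plane,
    and append those planes to the image as extra channels."""
    height, width = len(image[0]), len(image[0][0])
    planes = [[[value] * width for _ in range(height)] for value in table[label]]
    return image + planes

rng = random.Random(0)
table = make_embedding_table(num_classes=10, embed_dim=2, rng=rng)  # low-dimensional
image = [[[0.0] * 4 for _ in range(4)] for _ in range(3)]           # 3 channels, 4 x 4
conditioned = add_condition_channels(image, 7, table)               # 3 + 2 = 5 channels
```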

17: Use dropout in G in both the train and test phases
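A sketch assuming inverted dropout with p=0.5 (an illustrative choice) applied to a generator's hidden activations. The point is that the dropout call is never switched off, so repeated calls with the same input give different outputs at test time too.

```python
import random

def dropout(values, p, rng):
    """Inverted dropout: zero each value with probability p, rescale survivors by 1/(1-p)."""
    return [0.0 if rng.random() < p else v / (1.0 - p) for v in values]

def generator_layer(hidden, rng, p=0.5):
    # No train/eval switch: dropout is always applied, so G stays stochastic at test time
    return dropout(hidden, p, rng)

rng = random.Random(0)
outputs = [generator_layer([1.0] * 8, rng) for _ in range(3)]  # same input, varied outputs
```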

Authors

  • Soumith Chintala
  • Emily Denton
  • Martin Arjovsky
  • Michael Mathieu