Collapse by Conditioning: Training Class-conditional GANs with Limited Data

Mohamad Shahbazi, Martin Danelljan, Danda P. Paudel, Luc Van Gool
Paper: https://openreview.net/forum?id=7TZeCsNOUB_

Abstract

Class-conditioning offers a direct means of controlling a Generative Adversarial Network (GAN) based on a discrete input variable. While necessary in many applications, the additional information provided by the class labels could even be expected to benefit the training of the GAN itself. Contrary to this belief, we observe that class-conditioning causes mode collapse in limited data settings, where unconditional learning leads to satisfactory generative ability. Motivated by this observation, we propose a training strategy for conditional GANs (cGANs) that effectively prevents the observed mode-collapse by leveraging unconditional learning. Our training strategy starts with an unconditional GAN and gradually injects conditional information into the generator and the objective function. The proposed method for training cGANs with limited data results not only in stable training but also in generating high-quality images, thanks to the early-stage exploitation of the shared information across classes. We analyze the aforementioned mode collapse problem in comprehensive experiments on four datasets. Our approach demonstrates outstanding results compared with state-of-the-art methods and established baselines.
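As a rough illustration of "gradually injecting" conditional information, the sketch below assumes a scalar weight lam in [0, 1]. The weight scales a class embedding that is appended to the generator's latent input, and it also weights a conditional loss term that is added to an unconditional one. The helper names and the exact weighting are hypothetical and may differ from the paper's formulation and from the code in this repository.

```python
from typing import List, Sequence


def generator_input(z: Sequence[float], class_embedding: Sequence[float], lam: float) -> List[float]:
    """Append the class embedding to the latent z, scaled by lam (hypothetical helper).

    lam = 0.0 zeroes out the class information (unconditional input);
    lam = 1.0 passes the class embedding through unchanged (fully conditional input).
    """
    if not 0.0 <= lam <= 1.0:
        raise ValueError("lam must lie in [0, 1]")
    return list(z) + [lam * e for e in class_embedding]


def blended_objective(loss_uncond: float, loss_cond: float, lam: float) -> float:
    """Assumed weighting: the unconditional term is always active, the conditional term is ramped in."""
    return loss_uncond + lam * loss_cond


if __name__ == "__main__":
    z = [0.3, -1.2]
    emb = [1.0, 0.5]
    print(generator_input(z, emb, 0.0))  # class part is all zeros
    print(generator_input(z, emb, 1.0))  # class part equals the embedding
    print(blended_objective(0.7, 0.4, 0.5))
```

With lam fixed at 0 the model is trained as a plain unconditional GAN, which is the regime the abstract reports as stable under limited data.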

Overview

  1. Requirements
  2. Getting Started
  3. Dataset Preparation
  4. Training
  5. Evaluation and Logging
  6. Contact
  7. How to Cite

Requirements

  • Linux and Windows are supported, but Linux is recommended for performance and compatibility reasons.
  • For a batch size of 64, we used 4 NVIDIA GeForce RTX 2080 Ti GPUs (each with 11 GiB of memory).
  • 64-bit Python 3.7 and PyTorch 1.7.1. See https://pytorch.org/ for PyTorch installation instructions.
  • CUDA toolkit 11.0 or later. Use at least version 11.1 if running on an RTX 3090. (Why is a separate CUDA toolkit installation required? See the comments on this GitHub issue.)
  • Python libraries: pip install wandb click requests tqdm pyspng ninja imageio-ffmpeg==0.4.3.
  • This project uses Weights and Biases for visualization and logging. In addition to installing W&B (included in the command above), you need to create a free account on the W&B website. Then, log in to your account from the command line using the command wandb login (you will be prompted for your login information after running the command).
  • Docker users: use the Dockerfile provided by StyleGAN2+ADA (./Dockerfile) to build an image with the required library dependencies.

The code relies heavily on custom PyTorch extensions that are compiled on the fly using NVCC. On Windows, the compilation requires Microsoft Visual Studio. We recommend installing Visual Studio Community Edition and adding it to PATH using "C:\Program Files (x86)\Microsoft Visual Studio\ \Community\VC\Auxiliary\Build\vcvars64.bat".

Getting Started

The code for this project is based on the PyTorch implementation of StyleGAN2+ADA. Please first read the instructions provided for StyleGAN2+ADA. Here, we mainly provide the additional details required to use our method.

For a quick start, we have provided example scripts in ./scripts, as well as an example dataset (a tar file containing a subset of the ImageNet Carnivores dataset used in the paper) in ./datasets. Note that the scripts do not include the command for activating Python environments. Moreover, the paths for the dataset and output directories can be modified in the scripts based on your own setup.

The following command runs a script that extracts the tar file and creates a ZIP file in the same directory.

bash scripts/prepare_dataset_ImageNetCarnivores_20_100.sh

The ZIP file is later used for training and evaluation. For more details on how to use your custom datasets, see Dataset Preparation.

The following command runs a script that trains the model using our method with default hyper-parameters:

bash scripts/train_ImageNetCarnivores_20_100.sh

For more details on training with your custom datasets, see Training.

To calculate the evaluation metrics on a pretrained model, use the following command:

bash scripts/inference_metrics_ImageNetCarnivores_20_100.sh

Outputs from the training and inference commands are placed under out/ by default, controlled by --outdir. Downloaded network pickles are cached under $HOME/.cache/dnnlib, which can be overridden by setting the DNNLIB_CACHE_DIR environment variable. The default PyTorch extension build directory is $HOME/.cache/torch_extensions, which can be overridden by setting TORCH_EXTENSIONS_DIR.
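The override rule above amounts to "use the environment variable if it is set, otherwise fall back to a directory under $HOME/.cache". The sketch below mirrors that precedence for illustration only. It is not the actual dnnlib or PyTorch code, which may consult further locations, and the helper name resolve_cache_dir is hypothetical. In practice, the variables simply need to be set in the environment of the process that runs train.py or calc_metrics.py.

```python
import os
from typing import Mapping, Optional


def resolve_cache_dir(env_var: str, default_subdir: str,
                      environ: Optional[Mapping[str, str]] = None) -> str:
    """Return $env_var if it is set and non-empty, else $HOME/.cache/<default_subdir>.

    Hypothetical helper that mirrors the precedence described in this README.
    """
    env = os.environ if environ is None else environ
    override = env.get(env_var)
    if override:
        return override
    return os.path.join(os.path.expanduser("~"), ".cache", default_subdir)


if __name__ == "__main__":
    print("network pickles:", resolve_cache_dir("DNNLIB_CACHE_DIR", "dnnlib"))
    print("torch extensions:", resolve_cache_dir("TORCH_EXTENSIONS_DIR", "torch_extensions"))
```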

Dataset Prepration

Datasets are stored as uncompressed ZIP archives containing uncompressed PNG files and a metadata file dataset.json for labels.
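To check that labels were written as expected, the archive can be inspected with Python's standard zipfile and json modules. The sketch below assumes the StyleGAN2+ADA layout of dataset.json, i.e. {"labels": [[relative_path, class_index], ...]}, with "labels" set to null for unlabeled datasets. The helper name read_labels is hypothetical.

```python
import io
import json
import zipfile
from typing import BinaryIO, Dict, Union


def read_labels(archive: Union[str, BinaryIO]) -> Dict[str, int]:
    """Map each image path in the archive to its class index.

    Assumes dataset.json looks like {"labels": [["00000/img00000000.png", 3], ...]}
    and returns an empty dict when "labels" is null (unlabeled dataset).
    """
    with zipfile.ZipFile(archive) as zf:
        meta = json.loads(zf.read("dataset.json"))
    labels = meta.get("labels")
    if labels is None:
        return {}
    return {path: int(label) for path, label in labels}


if __name__ == "__main__":
    # Build a tiny in-memory archive, stored uncompressed as the README describes.
    buf = io.BytesIO()
    with zipfile.ZipFile(buf, "w", compression=zipfile.ZIP_STORED) as zf:
        zf.writestr("dataset.json", json.dumps(
            {"labels": [["00000/img00000000.png", 0], ["00000/img00000001.png", 2]]}))
    print(read_labels(buf))
```

Passing the path of the generated ZIP file (for example datasets/ImageNet_Carnivores_20_100.zip) instead of the in-memory buffer works the same way.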

Custom datasets can be created from a folder containing images (for multi-class datasets, each sub-directory contains the images of one class) using dataset_tool.py. Here is an example of how to convert the dataset folder to the desired ZIP file:

python dataset_tool.py --source=datasets/ImageNet_Carnivores_20_100 --dest=datasets/ImageNet_Carnivores_20_100.zip --transform=center-crop --width=128 --height=128

The above example reads the images from the image folder provided by --source, resizes them to the size given by --width and --height, and applies the center-crop transform to them. The resulting images, along with the metadata (label information), are stored as a ZIP file determined by --dest. See python dataset_tool.py --help for more information. See the StyleGAN2+ADA instructions for more details on specific datasets or legacy TFRecords datasets.
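The one-sub-directory-per-class convention maps naturally to integer labels. The sketch below shows one plausible mapping, assuming classes are indexed in sorted order of their directory names and only .png, .jpg and .jpeg files are counted. Whether dataset_tool.py orders classes the same way is an assumption, so check the generated dataset.json rather than relying on this sketch. The function name folder_labels is hypothetical.

```python
import os
import tempfile
from typing import List, Tuple

IMAGE_EXTS = {".png", ".jpg", ".jpeg"}


def folder_labels(root: str) -> List[Tuple[str, int]]:
    """Return (relative_path, class_index) pairs, one class per sub-directory of root.

    Class indices follow the sorted order of sub-directory names (an assumption).
    """
    classes = sorted(d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d)))
    pairs = []
    for idx, cls in enumerate(classes):
        for fname in sorted(os.listdir(os.path.join(root, cls))):
            # Skip non-image files such as notes or metadata.
            if os.path.splitext(fname)[1].lower() in IMAGE_EXTS:
                pairs.append((f"{cls}/{fname}", idx))
    return pairs


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as root:
        for cls, fname in [("fox", "a.png"), ("cat", "b.jpg"), ("cat", "notes.txt")]:
            os.makedirs(os.path.join(root, cls), exist_ok=True)
            open(os.path.join(root, cls, fname), "wb").close()
        print(folder_labels(root))  # [('cat/b.jpg', 0), ('fox/a.png', 1)]
```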

The created ZIP file can be passed to the training and evaluation code using --data argument.

Training

Training new networks can be done using train.py. To train with our method, the argument --cond should be set to 1, so that the training is done conditionally. In addition, the start and the end of the transition from unconditional to conditional training should be specified using the arguments --t_start_kimg and --t_end_kimg. Here is an example training command:

python train.py --outdir=./out/ \
--data=datasets/ImageNet_Carnivores_20_100.zip \
--cond=1 --t_start_kimg=2000  --t_end_kimg=4000  \
--gpus=4 \
--cfg=auto --mirror=1 \
--metrics=fid50k_full,kid50k_full
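Both transition arguments are measured in kimg, i.e. thousands of real images shown to the discriminator. The sketch below assumes a linear ramp of the conditioning weight between --t_start_kimg and --t_end_kimg (the transition function actually used by train.py may differ), and converts kimg into optimizer steps for a given batch size. Both function names are hypothetical.

```python
def transition_weight(cur_kimg: float, t_start_kimg: float, t_end_kimg: float) -> float:
    """Conditioning weight in [0, 1]: 0 before t_start, 1 after t_end, linear in between (assumed)."""
    if t_end_kimg <= t_start_kimg:
        # Degenerate schedule: switch abruptly at t_start.
        return 1.0 if cur_kimg >= t_start_kimg else 0.0
    frac = (cur_kimg - t_start_kimg) / (t_end_kimg - t_start_kimg)
    return min(max(frac, 0.0), 1.0)


def kimg_to_steps(kimg: float, batch_size: int) -> int:
    """Number of optimizer steps needed to show kimg * 1000 real images at the given batch size."""
    return int(kimg * 1000 // batch_size)


if __name__ == "__main__":
    for kimg in (0, 2000, 3000, 4000, 5000):
        print(kimg, transition_weight(kimg, 2000, 4000))
    print(kimg_to_steps(2000, 64), kimg_to_steps(4000, 64))  # 31250 62500
```

With the example values above and a batch size of 64 (the setting mentioned under Requirements), the transition would begin after 31,250 steps and complete after 62,500 steps.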

See the StyleGAN2+ADA instructions for more details on the arguments, configurations, and hyper-parameters. Please refer to python train.py --help for the full list of arguments.

Note: Currently, our code can only be used for unconditional or transitional training. For the original conditional training, you can use the original StyleGAN2+ADA implementation.

Evaluation and Logging

By default, train.py automatically computes FID for each network pickle exported during training. More metrics can be added to the argument --metrics (as a comma-separated list). To monitor the training, you can inspect the log.txt and JSONL files (e.g. metric-fid50k_full.jsonl for FID) saved in the output directory. Alternatively, you can inspect the WandB or TensorBoard logs (by default, WandB creates the logs under the project name "Transitional-cGAN", which can be accessed from your account on the website).
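Each metric-*.jsonl file holds one JSON record per evaluated snapshot. Assuming the StyleGAN2+ADA record layout (a "results" dictionary keyed by metric name plus a "snapshot_pkl" field), the best snapshot can be picked as sketched below. The helper name and the script name best_snapshot.py are hypothetical. The sketch assumes lower is better, which holds for FID and KID but not for precision/recall metrics such as pr50k3_full.

```python
import json
import sys
from typing import Iterable, Tuple


def best_snapshot(lines: Iterable[str], metric: str = "fid50k_full") -> Tuple[str, float]:
    """Return (snapshot_pkl, value) for the record with the lowest metric value.

    Assumes each line looks like {"results": {"fid50k_full": 12.3, ...}, "snapshot_pkl": "...", ...}.
    """
    best = None
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        record = json.loads(line)
        value = float(record["results"][metric])
        if best is None or value < best[1]:
            best = (record["snapshot_pkl"], value)
    if best is None:
        raise ValueError(f"no {metric} records found")
    return best


if __name__ == "__main__":
    # Usage: python best_snapshot.py <output-dir>/metric-fid50k_full.jsonl
    with open(sys.argv[1]) as f:
        print(best_snapshot(f))
```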

When desired, the automatic computation can be disabled with --metrics=none to speed up the training slightly (3-9%). Additional metrics can also be computed after the training:

# Previous training run: look up options automatically, save result to JSONL file.
python calc_metrics.py --metrics=pr50k3_full \
    --network=~/training-runs/00000-ffhq10k-res64-auto1/network-snapshot-000000.pkl

# Pre-trained network pickle: specify dataset explicitly, print result to stdout.
python calc_metrics.py --metrics=fid50k_full --data=~/datasets/ffhq.zip --mirror=1 \
    --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl

The first example looks up the training configuration and performs the same operation as if --metrics=pr50k3_full had been specified during training. The second example downloads a pre-trained network pickle, in which case the values of --mirror and --data must be specified explicitly.

See StyleGAN2+ADA instructions for more details on the available metrics.

Contact

For any questions, suggestions, or issues with the code, please contact Mohamad Shahbazi at [email protected]

How to Cite

@inproceedings{
shahbazi2022collapse,
title={Collapse by Conditioning: Training Class-conditional {GAN}s with Limited Data},
author={Shahbazi, Mohamad and Danelljan, Martin and Pani Paudel, Danda and Van Gool, Luc},
booktitle={The Tenth International Conference on Learning Representations},
year={2022},
url={https://openreview.net/forum?id=7TZeCsNOUB_}
}