Code for the paper "Training GANs with Stronger Augmentations via Contrastive Discriminator" (ICLR 2021)

Overview

Training GANs with Stronger Augmentations via Contrastive Discriminator (ICLR 2021)

This repository contains the code for reproducing the paper: Training GANs with Stronger Augmentations via Contrastive Discriminator by Jongheon Jeong and Jinwoo Shin.

TL;DR: We propose a novel discriminator of GAN showing that contrastive representation learning, e.g., SimCLR, and GAN can benefit each other when they are jointly trained.

Demo

Qualitative comparison of unconditional generations from GANs on high-resoultion, yet limited-sized datasets of AFHQ-Dog (4739 samples), AFHQ-Cat (5153 samples) and AFHQ-Wild (4738 samples) datasets.

Overview

Teaser

An overview of Contrastive Discriminator (ContraD). The representation of ContraD is not learned from the discriminator loss (L_dis), but from two contrastive losses (L+_con and L-_con), each is for the real and fake samples, respectively. The actual discriminator that minimizes L_dis is simply a 2-layer MLP head upon the learned contrastive representation.

Dependencies

Currently, the following environment has been confirmed to run the code:

  • python >= 3.6
  • pytorch >= 1.6.0 (See https://pytorch.org/ for the detailed installation)
  • tensorflow-gpu == 1.14.0 to run test_tf_inception.py for FID/IS evaluations
  • Other requirements can be found in environment.yml (for conda users) or environment_pip.txt (for pip users)
#### Install dependencies via conda.
# The file also includes `pytorch`, `tensorflow-gpu=1.14`, and `cudatoolkit=10.1`.
# You may have to set the correct version of `cudatoolkit` compatible to your system.
# This command creates a new conda environment named `contrad`.
conda env create -f environment.yml

#### Install dependencies via pip.
# It assumes `pytorch` and `tensorflow-gpu` are already installed in the current environment.
pip install -r environment_pip.txt

Preparing datasets

By default, the code assumes that all the datasets are placed under data/. You can change this path by setting the $DATA_DIR environment variable.

CIFAR-10/100 can be automatically downloaded by running any of the provided training scripts.

CelebA-HQ-128:

  1. Download the CelebA-HQ dataset and extract it under $DATA_DIR.
  2. Run third_party/preprocess_celeba_hq.py to resize and split the 1024x1024 images in $DATA_DIR/CelebAMask-HQ/CelebA-HQ-img:
    python third_party/preprocess_celeba_hq.py
    

AFHQ datasets:

  1. Download the AFHQ dataset and extract it under $DATA_DIR.
  2. One has to reorganize the directories in $DATA_DIR/afhq to make it compatible with torchvision.datasets.ImageFolder. Please refer the detailed file structure provided in below.

The structure of $DATA_DIR should be roughly like as follows:

$DATA_DIR
├── cifar-10-batches-py   # CIFAR-10
├── cifar-100-python      # CIFAR-100
├── CelebAMask-HQ         # CelebA-HQ-128
│   ├── CelebA-128-split  # Resized to 128x128 from `CelebA-HQ-img`
│   │   ├── train
│   │   │   └── images
│   │   │       ├── 0.jpg
│   │   │       └── ...
│   │   └── test
│   ├── CelebA-HQ-img     # Original 1024x1024 images
│   ├── CelebA-HQ-to-CelebA-mapping.txt
│   └── README.txt
└── afhq                  # AFHQ datasets
    ├── cat
    │   ├── train
    │   │   └── images
    │   │       ├── flickr_cat_00xxxx.jpg
    │   │       └── ...
    │   └── val
    ├── dog
    └── wild

Scripts

Training Scripts

We provide training scripts to reproduce the results in train_*.py, as listed in what follows:

File Description
train_gan.py Train a GAN model other than StyleGAN2. DistributedDataParallel supported.
train_stylegan2.py Train a StyleGAN2 model. It additionally implements the details of StyleGAN2 training, e.g., R1 regularization and EMA. DataParallel supported.
train_stylegan2_contraD.py Training script optimized for StyleGAN2 + ContraD. It runs faster especially on high-resolution datasets, e.g., 512x512 AFHQ. DataParallel supported.

The samples below demonstrate how to run each script to train GANs with ContraD. More instructions to reproduce our experiments, e.g., other baselines, can be found in EXPERIMENTS.md. One can modify CUDA_VISIBLE_DEVICES to further specify GPU number(s) to work on.

# SNDCGAN + ContraD on CIFAR-10
CUDA_VISIBLE_DEVICES=0 python train_gan.py configs/gan/cifar10/c10_b512.gin sndcgan \
--mode=contrad --aug=simclr --use_warmup

# StyleGAN2 + ContraD on CIFAR-10 - it is OK to simply use `train_stylegan2.py` even with ContraD
python train_stylegan2.py configs/gan/stylegan2/c10_style64.gin stylegan2 \
--mode=contrad --aug=simclr --lbd_r1=0.1 --no_lazy --halflife_k=1000 --use_warmup

# Nevertheless, StyleGAN2 + ContraD can be trained more efficiently with `train_stylegan2_contraD.py` 
python train_stylegan2_contraD.py configs/gan/stylegan2/afhq_dog_style64.gin stylegan2_512 \
--mode=contrad --aug=simclr_hq --lbd_r1=0.5 --halflife_k=20 --use_warmup \
--evaluate_every=5000 --n_eval_avg=1 --no_gif 

Testing Scripts

  • The script test_gan_sample.py generates and saves random samples from a pre-trained generator model into *.jpg files. For example,

    CUDA_VISIBLE_DEVICES=0 python test_gan_sample.py PATH/TO/G.pt sndcgan --n_samples=10000
    

    will load the generator stored at PATH/TO/G.pt, generate n_samples=10000 samples from it, and save them under PATH/TO/samples_*/.

  • The script test_gan_sample_cddls.py additionally takes the discriminator, and a linear evaluation head obtained from test_lineval.py to perform class-conditional cDDLS. For example,

    CUDA_VISIBLE_DEVICES=0 python test_gan_sample_cddls.py LOGDIR PATH/TO/LINEAR.pth.tar sndcgan
    

    will load G and D stored in LOGDIR, the linear head stored at PATH/TO/LINEAR.pth.tar, and save the generated samples from cDDLS under LOGDIR/samples_cDDLS_*/.

  • The script test_lineval.py performs linear evaluation for a given pre-trained discriminator model stored at model_path:

    CUDA_VISIBLE_DEVICES=0 python test_lineval.py PATH/TO/D.pt sndcgan
    
  • The script test_tf_inception.py computes Fréchet Inception distance (FID) and Inception score (IS) with TensorFlow backend using the original code of FID available at https://github.com/bioinf-jku/TTUR. tensorflow-gpu <= 1.14.0 is required to run this script. It takes a directory of generated samples (e.g., via test_gan_sample.py) and an .npz of pre-computed statistics:

    python test_tf_inception.py PATH/TO/GENERATED/IMAGES/ PATH/TO/STATS.npz --n_imgs=10000 --gpu=0 --verbose
    

    A pre-computed statistics file per dataset can be either found in http://bioinf.jku.at/research/ttur/, or manually computed - you can refer third_party/tf/examples for the sample scripts to this end.

Citation

@inproceedings{jeong2021contrad,
  title={Training {GAN}s with Stronger Augmentations via Contrastive Discriminator},
  author={Jongheon Jeong and Jinwoo Shin},
  booktitle={International Conference on Learning Representations},
  year={2021},
  url={https://openreview.net/forum?id=eo6U4CAwVmg}
}
Repository for "Improving evidential deep learning via multi-task learning," published in AAAI2022

Improving evidential deep learning via multi task learning It is a repository of AAAI2022 paper, “Improving evidential deep learning via multi-task le

deargen 11 Nov 19, 2022
Build a small, 3 domain internet using Github pages and Wikipedia and construct a crawler to crawl, render, and index.

TechSEO Crawler Build a small, 3 domain internet using Github pages and Wikipedia and construct a crawler to crawl, render, and index. Play with the r

JR Oakes 57 Nov 24, 2022
The fastest way to visualize GradCAM with your Keras models.

VizGradCAM VizGradCam is the fastest way to visualize GradCAM in Keras models. GradCAM helps with providing visual explainability of trained models an

58 Nov 19, 2022
Pytorch implementation for RelTransformer

RelTransformer Our Architecture This is a Pytorch implementation for RelTransformer The implementation for Evaluating on VG200 can be found here Requi

Vision CAIR Research Group, KAUST 21 Nov 22, 2022
The Rich Get Richer: Disparate Impact of Semi-Supervised Learning

The Rich Get Richer: Disparate Impact of Semi-Supervised Learning Preprocess file of the dataset used in implicit sub-populations: (Demographic groups

<a href=[email protected]"> 4 Oct 14, 2022
Codes for TS-CAM: Token Semantic Coupled Attention Map for Weakly Supervised Object Localization.

TS-CAM: Token Semantic Coupled Attention Map for Weakly SupervisedObject Localization This is the official implementaion of paper TS-CAM: Token Semant

vasgaowei 112 Jan 02, 2023
공공장소에서 눈만 돌리면 CCTV가 보인다는 말이 과언이 아닐 정도로 CCTV가 우리 생활에 깊숙이 자리 잡았습니다.

ObsCare_Main 소개 공공장소에서 눈만 돌리면 CCTV가 보인다는 말이 과언이 아닐 정도로 CCTV가 우리 생활에 깊숙이 자리 잡았습니다. CCTV의 대수가 급격히 늘어나면서 관리와 효율성 문제와 더불어, 곳곳에 설치된 CCTV를 개별 관제하는 것으로는 응급 상

5 Jul 07, 2022
免费获取http代理并生成proxifier配置文件

freeproxy 免费获取http代理并生成proxifier配置文件 公众号:台下言书 工具说明:https://mp.weixin.qq.com/s?__biz=MzIyNDkwNjQ5Ng==&mid=2247484425&idx=1&sn=56ccbe130822aa35038095317

说书人 32 Mar 25, 2022
A PyTorch implementation of NeRF (Neural Radiance Fields) that reproduces the results.

NeRF-pytorch NeRF (Neural Radiance Fields) is a method that achieves state-of-the-art results for synthesizing novel views of complex scenes. Here are

Yen-Chen Lin 3.2k Jan 08, 2023
Customised to detect objects automatically by a given model file(onnx)

LabelImg LabelImg is a graphical image annotation tool. It is written in Python and uses Qt for its graphical interface. Annotations are saved as XML

Heeone Lee 1 Jun 07, 2022
This is an official implementation for "Self-Supervised Learning with Swin Transformers".

Self-Supervised Learning with Vision Transformers By Zhenda Xie*, Yutong Lin*, Zhuliang Yao, Zheng Zhang, Qi Dai, Yue Cao and Han Hu This repo is the

Swin Transformer 529 Jan 02, 2023
Mixed Neural Likelihood Estimation for models of decision-making

Mixed neural likelihood estimation for models of decision-making Mixed neural likelihood estimation (MNLE) enables Bayesian parameter inference for mo

mackelab 9 Dec 22, 2022
PixelPyramids: Exact Inference Models from Lossless Image Pyramids (ICCV 2021)

PixelPyramids: Exact Inference Models from Lossless Image Pyramids This repository contains the PyTorch implementation of the paper PixelPyramids: Exa

Visual Inference Lab @TU Darmstadt 8 Dec 11, 2022
CKD - Collaborative Knowledge Distillation for Heterogeneous Information Network Embedding

Collaborative Knowledge Distillation for Heterogeneous Information Network Embed

zhousheng 9 Dec 05, 2022
Official implementation for "Symbolic Learning to Optimize: Towards Interpretability and Scalability"

Symbolic Learning to Optimize This is the official implementation for ICLR-2022 paper "Symbolic Learning to Optimize: Towards Interpretability and Sca

VITA 8 Dec 19, 2022
A tutorial on DataFrames.jl prepared for JuliaCon2021

JuliaCon2021 DataFrames.jl Tutorial This is a tutorial on DataFrames.jl prepared for JuliaCon2021. A video recording of the tutorial is available here

Bogumił Kamiński 106 Jan 09, 2023
Universal Probability Distributions with Optimal Transport and Convex Optimization

Sylvester normalizing flows for variational inference Pytorch implementation of Sylvester normalizing flows, based on our paper: Sylvester normalizing

Rianne van den Berg 172 Dec 13, 2022
Luminous is a framework for testing the performance of Embodied AI (EAI) models in indoor tasks.

Luminous is a framework for testing the performance of Embodied AI (EAI) models in indoor tasks. Generally, we intergrete different kind of functional

28 Jan 08, 2023
SCU OlympicsRunning Baseline

Competition 1v1 running Environment check details in Jidi Competition RLChina2021智能体竞赛 做出的修改: 奖励重塑:修改了环境,重新设置了奖励的分配,使得奖励组成不只有零和博弈,还有探索环境的奖励。 算法微调:修改了官

ZiSeoi Wong 2 Nov 23, 2021
3D HourGlass Networks for Human Pose Estimation Through Videos

3D-HourGlass-Network 3D CNN Based Hourglass Network for Human Pose Estimation (3D Human Pose) from videos. This was my summer'18 research project. Dis

Naman Jain 51 Jan 02, 2023