Ultra-Data-Efficient GAN Training: Drawing A Lottery Ticket First, Then Training It Toughly


License: MIT

Code for this paper Ultra-Data-Efficient GAN Training: Drawing A Lottery Ticket First, Then Training It Toughly. [Preprint]

Tianlong Chen, Yu Cheng, Zhe Gan, Jingjing Liu, Zhangyang Wang.

Overview

Training generative adversarial networks (GANs) with limited data generally results in deteriorated performance and collapsed models. To conquer this challenge, we are inspired by the recent observation of Kalibhat et al. (2020) and Chen et al. (2021d) that one can discover independently trainable and highly sparse subnetworks (a.k.a. lottery tickets) in GANs. Treating this as an inductive prior, we decompose the data-hungry GAN training into two sequential sub-problems:

  • (i) identifying the lottery ticket from the original GAN;
  • (ii) then training the found sparse subnetwork with aggressive data and feature augmentations.

Both sub-problems re-use the same small training set of real images. Such a coordinated framework enables us to focus on lower-complexity and more data-efficient sub-problems, effectively stabilizing training and improving convergence.
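Concretely, a minimal sketch of sub-problem (i) could look like the snippet below: magnitude-prune a briefly trained full GAN to obtain binary masks, then rewind the surviving weights to their initial values. All helper and variable names here are illustrative assumptions, not this repository's actual code.

import copy
import torch
import torch.nn as nn

def magnitude_masks(model, keep_ratio=0.8):
    """One pruning round: keep the largest-magnitude weights of every weight tensor."""
    masks = {}
    for name, p in model.named_parameters():
        if p.dim() > 1:  # prune weight matrices / conv kernels only
            k = max(1, int(keep_ratio * p.numel()))
            thresh = p.detach().abs().flatten().kthvalue(p.numel() - k + 1).values
            masks[name] = (p.detach().abs() >= thresh).float()
    return masks

# Stage (i): briefly train a full GAN on the small dataset, prune, then rewind.
G_init = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 3 * 32 * 32))
G = copy.deepcopy(G_init)
# ... train G (and its discriminator) on the limited data for a while ...
masks = magnitude_masks(G)            # iterate n rounds for 0.8^n remaining weights
ticket = copy.deepcopy(G_init)        # rewind: the ticket is initial weights x masks
with torch.no_grad():
    for name, p in ticket.named_parameters():
        if name in masks:
            p.mul_(masks[name])
# Stage (ii) then re-trains `ticket` on the same limited data with aggressive augmentation.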

Methodology

Experiment Results

More experiments can be found in our paper.

Implementation

For the first step, identifying the lottery tickets in GANs, please refer to this repo.

For the second step, the commands for training the found GAN tickets toughly are provided as follows:

Environment for SNGAN

conda install python=3.6
conda install pytorch=1.4.0 -c pytorch
pip install tensorflow-gpu==1.13
pip install imageio
pip install tensorboardx

R.K. Download the FID statistics from Fid_Stat.
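The statistics file generally stores the mean and covariance of Inception features for the real dataset; given the corresponding statistics for generated images, FID is the Fréchet distance between the two Gaussians. A small reference sketch follows; the npz key names ('mu', 'sigma') are an assumption about the downloaded files.

import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    # FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 (sigma1 sigma2)^(1/2))
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return diff.dot(diff) + np.trace(sigma1 + sigma2 - 2.0 * covmean)

# Hypothetical usage (file name and keys are assumptions):
# real = np.load('fid_stats_cifar10_train.npz')
# fid = frechet_distance(gen_mu, gen_sigma, real['mu'], real['sigma'])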

Commands for SNGAN

R.K. Limited data training for SNGAN

  • Dataset: CIFAR-10

Example for full model training on 20% limited data (--ratio 0.2):

python train_less.py -gen_bs 128 -dis_bs 64 --dataset cifar10 --img_size 32 --max_iter 50000 --model sngan_cifar10 --latent_dim 128 --gf_dim 256 --df_dim 128 --g_spectral_norm False --d_spectral_norm True --g_lr 0.0002 --d_lr 0.0002 --beta1 0.0 --beta2 0.9 --init_type xavier_uniform --n_critic 5 --val_freq 20 --exp_name sngan_cifar10_adv_gd_less_0.2 --init-path initial_weights --ratio 0.2

Example for full model training on 20% limited data (--ratio 0.2) with AdvAug on G and D:

python train_adv_gd_less.py -gen_bs 128 -dis_bs 64 --dataset cifar10 --img_size 32 --max_iter 50000 --model sngan_cifar10 --latent_dim 128 --gf_dim 256 --df_dim 128 --g_spectral_norm False --d_spectral_norm True --g_lr 0.0002 --d_lr 0.0002 --beta1 0.0 --beta2 0.9 --init_type xavier_uniform --n_critic 5 --val_freq 20 --exp_name sngan_cifar10_adv_gd_less_0.2 --init-path initial_weights --gamma 0.01 --step 1 --ratio 0.2

Example for sparse model (i.e., GAN tickets) training on 20% limited data (--ratio 0.2) with AdvAug on G and D:

python train_with_masks_adv_gd_less.py -gen_bs 128 -dis_bs 64 --dataset cifar10 --img_size 32 --max_iter 50000 --model sngan_cifar10 --latent_dim 128 --gf_dim 256 --df_dim 128 --g_spectral_norm False --d_spectral_norm True --g_lr 0.0002 --d_lr 0.0002 --beta1 0.0 --beta2 0.9 --init_type xavier_uniform --n_critic 5 --val_freq 20 --exp_name sngan_cifar10_adv_gd_less_0.2 --init-path initial_weights --gamma 0.01 --step 1 --ratio 0.2 --rewind-path <>
  • --rewind-path: the stored path of identified sparse masks
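During this sparse re-training, the loaded masks have to be enforced throughout the alternating G/D updates so that pruned weights never grow back. Below is a schematic sketch of where that enforcement sits; the helper name and mask-file layout are assumptions, not the script's actual API.

import torch

def enforce_masks(module, masks):
    # Re-zero pruned weights; call right after every optimizer step.
    with torch.no_grad():
        for name, p in module.named_parameters():
            if name in masks:
                p.mul_(masks[name])

# Alternating updates (schematic):
#   d_loss.backward(); d_optimizer.step(); enforce_masks(D, d_masks)
#   g_loss.backward(); g_optimizer.step(); enforce_masks(G, g_masks)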

Environment for BigGAN

conda env create -f environment.yml -n studiogan

Commands for BigGAN

R.K. Limited data training for BigGAN

  • Dataset: TINY ILSVRC

Example:

python main_ompg.py -t -e -c ./configs/TINY_ILSVRC2012/BigGAN_adv.json --eval_type valid --seed 42 --mask_path checkpoints/BigGAN-train-0.1 --mask_round 2 --reduce_train_dataset 0.1 --gamma 0.01 
  • --mask_path: the stored path of identified sparse masks
  • --mask_round: controls the sparsity level; the fraction of remaining weights is 0.8 ^ mask_round (see the worked example after this list)
  • --reduce_train_dataset: the size of used limited training data
  • --gamma: hyperparameter for AdvAug. You can set it to 0 to get rid of AdvAug
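For reference, the remaining-weight fraction implied by --mask_round can be checked with a one-liner; the values below match the pre-trained models listed later in this README.

# Fraction of remaining weights after `mask_round` pruning rounds (20% removed per round).
for mask_round in (2, 9, 10):
    print(mask_round, round(0.8 ** mask_round, 4))
# 2  -> 0.64    (64% remaining weights)
# 9  -> 0.1342  (13.42% remaining weights)
# 10 -> 0.1074  (10.74% remaining weights)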

  • Dataset: CIFAR100

Example:

python main_ompg.py -t -e -c ./configs/CIFAR100_less/DiffAugGAN_adv.json --ratio 0.2 --mask_path checkpoints/diffauggan_cifar100_0.2 --mask_round 9 --seed 42 --gamma 0.01
  • DiffAugGAN_adv.json: this configuration uses DiffAug.
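DiffAug (differentiable augmentation) applies the same random, differentiable transforms to both real and generated images before the discriminator, so gradients still flow back to the generator through the augmentation. A simplified sketch of the idea is below; the transforms are stand-ins, not the exact policy used by these configs.

import torch
import torch.nn.functional as F

def diff_augment(x):
    # Random brightness jitter (per image) plus a random horizontal shift,
    # both kept inside the autograd graph so G receives gradients through them.
    x = x + (torch.rand(x.size(0), 1, 1, 1, device=x.device) - 0.5)
    shift = max(1, int(x.size(3) * 0.125))
    padded = F.pad(x, [shift, shift, 0, 0])
    offset = int(torch.randint(0, 2 * shift + 1, (1,)))
    return padded[:, :, :, offset:offset + x.size(3)]

# During training the discriminator always sees augmented images:
#   d_real = D(diff_augment(real_images))
#   d_fake = D(diff_augment(G(z)))   # gradients reach G through diff_augment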

Pre-trained Models

  • SNGAN / CIFAR-10 / 10% Training Data / 10.74% Remaining Weights

https://www.dropbox.com/sh/7v8hn2859cvm7jj/AACyN8FOkMjgMwy5ibVj61IPa?dl=0

  • SNGAN / CIFAR-10 / 10% Training Data / 10.74% Remaining Weights + AdvAug on G and D

https://www.dropbox.com/sh/gsklrdcjzogqzcd/AAALlIYcWOZuERLcocKIqlEya?dl=0

  • BigGAN / CIFAR-10 / 10% Training Data / 13.42% Remaining Weights + DiffAug + AdvAug on G and D

https://www.dropbox.com/sh/epuajb1iqn5xma6/AAAD0zwehky1wvV3M3-uesHsa?dl=0

  • BigGAN / CIFAR-100 / 10% Training Data / 13.42% Remaining Weights + DiffAug + AdvAug on G and D

https://www.dropbox.com/sh/y3pqdqee39jpct4/AAAsSebqHwkWmjO_O8Hp0hcEa?dl=0

  • BigGAN / Tiny-ImageNet / 10% Training Data / Full model

https://www.dropbox.com/sh/2rmvqwgcjir1p2l/AABNEo0B-0V9ZSnLnKF_OUA3a?dl=0

  • BigGAN / Tiny-ImageNet / 10% Training Data / Full model + AdvAug on G and D

https://www.dropbox.com/sh/pbwjphualzdy2oe/AACZ7VYJctNBKz3E9b8fgj_Ia?dl=0

  • BigGAN / Tiny-ImageNet / 10% Training Data / 64% Remaining Weights

https://www.dropbox.com/sh/82i9z44uuczj3u3/AAARsfNzOgd1R9sKuh1OqUdoa?dl=0

  • BigGAN / Tiny-ImageNet / 10% Training Data / 64% Remaining Weights + AdvAug on G and D

https://www.dropbox.com/sh/yknk1joigx0ufbo/AAChMvzCsedejFjY1XxGcaUta?dl=0

Citation

@misc{chen2021ultradataefficient,
      title={Ultra-Data-Efficient GAN Training: Drawing A Lottery Ticket First, Then Training It Toughly}, 
      author={Tianlong Chen and Yu Cheng and Zhe Gan and Jingjing Liu and Zhangyang Wang},
      year={2021},
      eprint={2103.00397},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Acknowledgement

https://github.com/VITA-Group/GAN-LTH

https://github.com/GongXinyuu/sngan.pytorch

https://github.com/VITA-Group/AutoGAN

https://github.com/POSTECH-CVLab/PyTorch-StudioGAN

https://github.com/mit-han-lab/data-efficient-gans

https://github.com/lucidrains/stylegan2-pytorch
