Ultra-Data-Efficient GAN Training: Drawing A Lottery Ticket First, Then Training It Toughly

Overview

Ultra-Data-Efficient GAN Training: Drawing A Lottery Ticket First, Then Training It Toughly

License: MIT

Code for this paper Ultra-Data-Efficient GAN Training: Drawing A Lottery Ticket First, Then Training It Toughly. [Preprint]

Tianlong Chen, Yu Cheng, Zhe Gan, Jingjing Liu, Zhangyang Wang.

Overview

Training generative adversarial networks (GANs) with limited data generally results in deteriorated performance and collapsed models. To conquerthis challenge, we are inspired by the latest observation of Kalibhat et al. (2020); Chen et al.(2021d), that one can discover independently trainable and highly sparse subnetworks (a.k.a.,lottery tickets) from GANs. Treating this as aninductive prior, we decompose the data-hungry GAN training into two sequential sub-problems:

  • (i) identifying the lottery ticket from the original GAN;
  • (ii) then training the found sparse subnetwork with aggressive data and feature augmentations.

Both sub-problems re-use the same small training set of real images. Such a coordinated framework enables us to focus on lower-complexity and more data-efficient sub-problems, effectively stabilizing trainingand improving convergence.

Methodology

Experiment Results

More experiments can be found in our paper.

Implementation

For the first step, finding the lottery tickets in GAN is referred to this repo.

For the second step, training GAN ticket toughly are provides as follow:

Environment for SNGAN

conda install python3.6
conda install pytorch1.4.0 -c pytorch
pip install tensorflow-gpu==1.13
pip install imageio
pip install tensorboardx

R.K. Donwload fid statistics from Fid_Stat.

Commands for SNGAN

R.K. Limited data training for SNGAN

  • Dataset: CIFAR-10

Example for full model training on 20% limited data (--ratio 0.2):

python train_less.py -gen_bs 128 -dis_bs 64 --dataset cifar10 --img_size 32 --max_iter 50000 --model sngan_cifar10 --latent_dim 128 --gf_dim 256 --df_dim 128 --g_spectral_norm False --d_spectral_norm True --g_lr 0.0002 --d_lr 0.0002 --beta1 0.0 --beta2 0.9 --init_type xavier_uniform --n_critic 5 --val_freq 20 --exp_name sngan_cifar10_adv_gd_less_0.2 --init-path initial_weights --ratio 0.2

Example for full model training on 20% limited data (--ratio 0.2) with AdvAug on G and D:

python train_adv_gd_less.py -gen_bs 128 -dis_bs 64 --dataset cifar10 --img_size 32 --max_iter 50000 --model sngan_cifar10 --latent_dim 128 --gf_dim 256 --df_dim 128 --g_spectral_norm False --d_spectral_norm True --g_lr 0.0002 --d_lr 0.0002 --beta1 0.0 --beta2 0.9 --init_type xavier_uniform --n_critic 5 --val_freq 20 --exp_name sngan_cifar10_adv_gd_less_0.2 --init-path initial_weights --gamma 0.01 --step 1 --ratio 0.2

Example for sparse model (i.e., GAN tickets) training on 20% limited data (--ratio 0.2) with AdvAug on G and D:

python train_with_masks_adv_gd_less.py -gen_bs 128 -dis_bs 64 --dataset cifar10 --img_size 32 --max_iter 50000 --model sngan_cifar10 --latent_dim 128 --gf_dim 256 --df_dim 128 --g_spectral_norm False --d_spectral_norm True --g_lr 0.0002 --d_lr 0.0002 --beta1 0.0 --beta2 0.9 --init_type xavier_uniform --n_critic 5 --val_freq 20 --exp_name sngan_cifar10_adv_gd_less_0.2 --init-path initial_weights --gamma 0.01 --step 1 --ratio 0.2 --rewind-path <>
  • --rewind-path: the stored path of identified sparse masks

Environment for BigGAN

conda env create -f environment.yml studiogan

Commands for BigGAN

R.K. Limited data training for BigGAN

  • Dataset: TINY ILSVRC

Example:

python main_ompg.py -t -e -c ./configs/TINY_ILSVRC2012/BigGAN_adv.json --eval_type valid --seed 42 --mask_path checkpoints/BigGAN-train-0.1 --mask_round 2 --reduce_train_dataset 0.1 --gamma 0.01 
  • --mask_path: the stored path of identified sparse masks
  • --mask_round: the sparsity level = 0.8 ^ mask_round
  • --reduce_train_dataset: the size of used limited training data
  • --gamma: hyperparameter for AdvAug. You can set it to 0 to git rid of AdvAug

  • Dataset: CIFAR100

Example:

python main_ompg.py -t -e -c ./configs/CIFAR100_less/DiffAugGAN_adv.json --ratio 0.2 --mask_path checkpoints/diffauggan_cifar100_0.2 --mask_round 9 --seed 42 --gamma 0.01
  • DiffAugGAN_adv.json: it indicate this confirguration use DiffAug.

Pre-trained Models

  • SNGAN / CIFAR-10 / 10% Training Data / 10.74% Remaining Weights

https://www.dropbox.com/sh/7v8hn2859cvm7jj/AACyN8FOkMjgMwy5ibVj61IPa?dl=0

  • SNGAN / CIFAR-10 / 10% Training Data / 10.74% Remaining Weights + AdvAug on G and D

https://www.dropbox.com/sh/gsklrdcjzogqzcd/AAALlIYcWOZuERLcocKIqlEya?dl=0

  • BigGAN / CIFAR-10 / 10% Training Data / 13.42% Remaining Weights + DiffAug + AdvAug on G and D

https://www.dropbox.com/sh/epuajb1iqn5xma6/AAAD0zwehky1wvV3M3-uesHsa?dl=0

  • BigGAN / CIFAR-100 10% / Training Data / 13.42% Remaining Weights + DiffAug + AdvAug on G and D

https://www.dropbox.com/sh/y3pqdqee39jpct4/AAAsSebqHwkWmjO_O8Hp0hcEa?dl=0

  • BigGAN / Tiny-ImageNet / 10% Training Data / Full model

https://www.dropbox.com/sh/2rmvqwgcjir1p2l/AABNEo0B-0V9ZSnLnKF_OUA3a?dl=0

  • BigGAN / Tiny-ImageNet / 10% Training Data / Full model + AdvAug on G and D

https://www.dropbox.com/sh/pbwjphualzdy2oe/AACZ7VYJctNBKz3E9b8fgj_Ia?dl=0

  • BigGAN / Tiny-ImageNet / 10% Training Data / 64% Remaining Weights

https://www.dropbox.com/sh/82i9z44uuczj3u3/AAARsfNzOgd1R9sKuh1OqUdoa?dl=0

  • BigGAN / Tiny-ImageNet / 10% Training Data / 64% Remaining Weights + AdvAug on G and D

https://www.dropbox.com/sh/yknk1joigx0ufbo/AAChMvzCsedejFjY1XxGcaUta?dl=0

Citation

@misc{chen2021ultradataefficient,
      title={Ultra-Data-Efficient GAN Training: Drawing A Lottery Ticket First, Then Training It Toughly}, 
      author={Tianlong Chen and Yu Cheng and Zhe Gan and Jingjing Liu and Zhangyang Wang},
      year={2021},
      eprint={2103.00397},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Acknowledgement

https://github.com/VITA-Group/GAN-LTH

https://github.com/GongXinyuu/sngan.pytorch

https://github.com/VITA-Group/AutoGAN

https://github.com/POSTECH-CVLab/PyTorch-StudioGAN

https://github.com/mit-han-lab/data-efficient-gans

https://github.com/lucidrains/stylegan2-pytorch

Owner
VITA
Visual Informatics Group @ University of Texas at Austin
VITA
Camera calibration & 3D pose estimation tools for AcinoSet

AcinoSet: A 3D Pose Estimation Dataset and Baseline Models for Cheetahs in the Wild Daniel Joska, Liam Clark, Naoya Muramatsu, Ricardo Jericevich, Fre

African Robotics Unit 42 Nov 16, 2022
Datasets for new state-of-the-art challenge in disentanglement learning

High resolution disentanglement datasets This repository contains the Falcor3D and Isaac3D datasets, which present a state-of-the-art challenge for co

NVIDIA Research Projects 37 May 26, 2022
Deep Learning Head Pose Estimation using PyTorch.

Hopenet is an accurate and easy to use head pose estimation network. Models have been trained on the 300W-LP dataset and have been tested on real data with good qualitative performance.

Nataniel Ruiz 1.3k Dec 26, 2022
PyTorch Implementation of Fully Convolutional Networks. (Training code to reproduce the original result is available.)

pytorch-fcn PyTorch implementation of Fully Convolutional Networks. Requirements pytorch = 0.2.0 torchvision = 0.1.8 fcn = 6.1.5 Pillow scipy tqdm

Kentaro Wada 1.6k Jan 07, 2023
🤖 Project template for your next awesome AI project. 🦾

🤖 AI Awesome Project Template 👋 Template author You may want to adjust badge links in a README.md file. 💎 Installation with pip Installation is as

Wiktor Łazarski 18 Nov 23, 2022
Simply enable or disable your Nvidia dGPU

EnvyControl (WIP) Simply enable or disable your Nvidia dGPU Usage First clone this repo and install envycontrol with sudo pip install . CLI Turn off y

Victor Bayas 292 Jan 03, 2023
Repo público onde postarei meus estudos de Python, buscando aprender por meio do compartilhamento do aprendizado!

Seja bem vindo à minha repo de Estudos em Python 3! Este é um repositório criado por um programador amador que estuda tópicos de finanças, estatística

32 Dec 24, 2022
Single-stage Keypoint-based Category-level Object Pose Estimation from an RGB Image

CenterPose Overview This repository is the official implementation of the paper "Single-stage Keypoint-based Category-level Object Pose Estimation fro

NVIDIA Research Projects 188 Dec 27, 2022
Le dataset des images du projet d'IA de 2021

face-mask-dataset-ilc-2021 Le dataset des images du projet d'IA de 2021, Indiquez vos id git dans la issue pour les droits TL;DR: Choisir 200 images J

7 Nov 15, 2021
Learning with Noisy Labels via Sparse Regularization, ICCV2021

Learning with Noisy Labels via Sparse Regularization This repository is the official implementation of [Learning with Noisy Labels via Sparse Regulari

Xiong Zhou 38 Oct 20, 2022
This project uses Template Matching technique for object detecting by detection of template image over base image.

Object Detection Project Using OpenCV This project uses Template Matching technique for object detecting by detection the template image over base ima

Pratham Bhatnagar 7 May 29, 2022
OpenMMLab Computer Vision Foundation

English | 简体中文 Introduction MMCV is a foundational library for computer vision research and supports many research projects as below: MMCV: OpenMMLab

OpenMMLab 4.6k Jan 09, 2023
Code for PhySG: Inverse Rendering with Spherical Gaussians for Physics-based Relighting and Material Editing

PhySG: Inverse Rendering with Spherical Gaussians for Physics-based Relighting and Material Editing CVPR 2021. Project page: https://kai-46.github.io/

Kai Zhang 141 Dec 14, 2022
A minimalist environment for decision-making in autonomous driving

highway-env A collection of environments for autonomous driving and tactical decision-making tasks An episode of one of the environments available in

Edouard Leurent 1.6k Jan 07, 2023
A configurable, tunable, and reproducible library for CTR prediction

FuxiCTR This repo is the community dev version of the official release at huawei-noah/benchmark/FuxiCTR. Click-through rate (CTR) prediction is an cri

XUEPAI 397 Dec 30, 2022
This repository provides the official implementation of 'Learning to ignore: rethinking attention in CNNs' accepted in BMVC 2021.

inverse_attention This repository provides the official implementation of 'Learning to ignore: rethinking attention in CNNs' accepted in BMVC 2021. Le

Firas Laakom 5 Jul 08, 2022
This is the implementation of the paper "Self-supervised Outdoor Scene Relighting"

Self-supervised Outdoor Scene Relighting This is the implementation of the paper "Self-supervised Outdoor Scene Relighting". The model is implemented

Ye Yu 24 Dec 17, 2022
Tree LSTM implementation in PyTorch

Tree-Structured Long Short-Term Memory Networks This is a PyTorch implementation of Tree-LSTM as described in the paper Improved Semantic Representati

Riddhiman Dasgupta 529 Dec 10, 2022
RealFormer-Pytorch Implementation of RealFormer using pytorch

RealFormer-Pytorch Implementation of RealFormer using pytorch. Includes comparison with classical Transformer on image classification task (ViT) wrt C

Simo Ryu 90 Dec 08, 2022