Spectral normalization (SN) is a widely-used technique for improving the stability and sample quality of Generative Adversarial Networks (GANs)

Overview

Why Spectral Normalization Stabilizes GANs: Analysis and Improvements

[paper (NeurIPS 2021)] [paper (arXiv)] [code]

Authors: Zinan Lin, Vyas Sekar, Giulia Fanti

Abstract: Spectral normalization (SN) is a widely-used technique for improving the stability and sample quality of Generative Adversarial Networks (GANs). However, there is currently limited understanding of why SN is effective. In this work, we show that SN controls two important failure modes of GAN training: exploding and vanishing gradients. Our proofs illustrate a (perhaps unintentional) connection with the successful LeCun initialization. This connection helps to explain why the most popular implementation of SN for GANs requires no hyper-parameter tuning, whereas stricter implementations of SN have poor empirical performance out-of-the-box. Unlike LeCun initialization which only controls gradient vanishing at the beginning of training, SN preserves this property throughout training. Building on this theoretical understanding, we propose a new spectral normalization technique: Bidirectional Scaled Spectral Normalization (BSSN), which incorporates insights from later improvements to LeCun initialization: Xavier initialization and Kaiming initialization. Theoretically, we show that BSSN gives better gradient control than SN. Empirically, we demonstrate that it outperforms SN in sample quality and training stability on several benchmark datasets.


This repo contains the codes for reproducing the experiments of our BSN and different SN variants in the paper. The codes were tested under Python 2.7.5, TensorFlow 1.14.0.

Preparing datasets

CIFAR10

Download cifar-10-python.tar.gz from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz (or from other sources).

STL10

Download stl10_binary.tar.gz from http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz (or from other sources), and put it in dataset_preprocess/STL10 folder. Then run python preprocess.py. This code will resize the images into 48x48x3 format, and save the images in stl10.npy.

CelebA

Download img_align_celeba.zip from https://www.kaggle.com/jessicali9530/celeba-dataset (or from other sources), and put it in dataset_preprocess/CelebA folder. Then run python preprocess.py. This code will crop and resize the images into 64x64x3 format, and save the images in celeba.npy.

ImageNet

Download ILSVRC2012_img_train.tar from http://www.image-net.org/ (or from other sources), and put it in dataset_preprocess/ImageNet folder. Then run python preprocess.py. This code will crop and resize the images into 128x128x3 format, and save the images in ILSVRC2012folder. Each subfolder in ILSVRC2012 folder corresponds to one class. Each npy file in the subfolders corresponds to an image.

Training BSN and SN variants

Prerequisites

The codes are based on GPUTaskScheduler library, which helps you automatically schedule the jobs among GPU nodes. Please install it first. You may need to change GPU configurations according to the devices you have. The configurations are set in config.py in each directory. Please refer to GPUTaskScheduler's GitHub page for the details of how to make proper configurations.

You can also run these codes without GPUTaskScheduler. Just run python gan.py in gan subfolders.

CIFAR10, STL10, CelebA

Preparation

Copy the preprocessed datasets from the previous steps into the following paths:

  • CIFAR10: /data/CIFAR10/cifar-10-python.tar.gz.
  • STL10: /data/STL10/cifar-10-stl10.npy.
  • CelebA: /data/CelebA/celeba.npy.

Here means

  • Vanilla SN and our proposed BSSN/SSN/BSN without gammas: no_gamma-CNN.
  • SN with the same gammas: same_gamma-CNN.
  • SN with different gammas: diff_gamma-CNN.

Alternatively, you can directly modify the dataset paths in /gan_task.py to the path of the preprocessed dataset folders.

Running codes

Now you can directly run python main.py in each to train the models.

All the configurable hyper-parameters can be set in config.py. The hyper-parameters in the file are already set for reproducing the results in the paper. Please refer to GPUTaskScheduler's GitHub page for the details of the grammar of this file.

ImageNet

Preparation

Copy the preprocessed folder ILSVRC2012 from the previous steps to /data/imagenet/ILSVRC2012, where means

  • Vanilla SN and our proposed BSSN/SSN/BSN without gammas: no_gamma-ResNet.

Alternatively, you can directly modify the dataset path in /gan_task.py to the path of the preprocessed folder ILSVRC2012.

Running codes

Now you can directly run python main.py in each to train the models.

All the configurable hyper-parameters can be set in config.py. The hyper-parameters in the file are already set for reproducing the results in the paper. Please refer to GPUTaskScheduler's GitHub page for the details of the grammar of this file.

The code supports multi-GPU training for speed-up, by separating each data batch equally among multiple GPUs. To do that, you only need to make minor modifications in config.py. For example, if you have two GPUs with IDs 0 and 1, then all you need to do is to (1) change "gpu": ["0"] to "gpu": [["0", "1"]], and (2) change "num_gpus": [1] to "num_gpus": [2]. Note that the number of GPUs might influence the results because in this implementation the batch normalization layers on different GPUs are independent. In our experiments, we were using only one GPU.

Results

The code generates the following result files/folders:

  • /results/ /worker.log : Standard output and error from the code.
  • /results/ /metrics.csv : Inception Score and FID during training.
  • /results/ /sample/*.png : Generated images during training.
  • /results/ /checkpoint/* : TensorFlow checkpoints.
  • /results/ /time.txt : Training iteration timestamps.
Owner
Zinan Lin
Ph.D. student at Electrical and Computer Engineering, Carnegie Mellon University
Zinan Lin
Tianshou - An elegant PyTorch deep reinforcement learning library.

Tianshou (天授) is a reinforcement learning platform based on pure PyTorch. Unlike existing reinforcement learning libraries, which are mainly based on

Tsinghua Machine Learning Group 5.5k Jan 05, 2023
Benchmarks for Object Detection in Aerial Images

Benchmarks for Object Detection in Aerial Images

Jian Ding 691 Dec 30, 2022
SHRIMP: Sparser Random Feature Models via Iterative Magnitude Pruning

SHRIMP: Sparser Random Feature Models via Iterative Magnitude Pruning This repository is the official implementation of "SHRIMP: Sparser Random Featur

Bobby Shi 0 Dec 16, 2021
🥇Samsung AI Challenge 2021 1등 솔루션입니다🥇

MoT - Molecular Transformer Large-scale Pretraining for Molecular Property Prediction Samsung AI Challenge for Scientific Discovery This repository is

Jungwoo Park 44 Dec 03, 2022
Iran Open Source Hackathon

Iran Open Source Hackathon is an open-source hackathon (duh) with the aim of encouraging participation in open-source contribution amongst Iranian dev

OSS Hackathon 121 Dec 25, 2022
dualPC.R contains the R code for the main functions.

dualPC.R contains the R code for the main functions. dualPC_sim.R contains an example run with the different PC versions; it calls dualPC_algs.R whic

3 May 30, 2022
Bio-OFC gym implementation and Gym-Fly environment

Bio-OFC gym implementation and Gym-Fly environment This repository includes the gym compatible implementation of the Bio-OFC algorithm from the paper

Siavash Golkar 1 Nov 16, 2021
Research code for the paper "Variational Gibbs inference for statistical estimation from incomplete data".

Variational Gibbs inference (VGI) This repository contains the research code for Simkus, V., Rhodes, B., Gutmann, M. U., 2021. Variational Gibbs infer

Vaidotas Šimkus 1 Apr 08, 2022
Code release for BlockGAN: Learning 3D Object-aware Scene Representations from Unlabelled Images

BlockGAN Code release for BlockGAN: Learning 3D Object-aware Scene Representations from Unlabelled Images BlockGAN: Learning 3D Object-aware Scene Rep

41 May 18, 2022
A simple pytorch pipeline for semantic segmentation.

SegmentationPipeline -- Pytorch A simple pytorch pipeline for semantic segmentation. Requirements : torch=1.9.0 tqdm albumentations=1.0.3 opencv-pyt

petite7 4 Feb 22, 2022
Reproduce results and replicate training fo T0 (Multitask Prompted Training Enables Zero-Shot Task Generalization)

T-Zero This repository serves primarily as codebase and instructions for training, evaluation and inference of T0. T0 is the model developed in Multit

BigScience Workshop 253 Dec 27, 2022
Playable Video Generation

Playable Video Generation Playable Video Generation Willi Menapace, Stéphane Lathuilière, Sergey Tulyakov, Aliaksandr Siarohin, Elisa Ricci Paper: ArX

Willi Menapace 136 Dec 31, 2022
Generic Event Boundary Detection: A Benchmark for Event Segmentation

Generic Event Boundary Detection: A Benchmark for Event Segmentation We release our data annotation & baseline codes for detecting generic event bound

47 Nov 22, 2022
Official Implementation for Fast Training of Neural Lumigraph Representations using Meta Learning.

Fast Training of Neural Lumigraph Representations using Meta Learning Project Page | Paper | Data Alexander W. Bergman, Petr Kellnhofer, Gordon Wetzst

Alex 39 Oct 08, 2022
A Flow-based Generative Network for Speech Synthesis

WaveGlow: a Flow-based Generative Network for Speech Synthesis Ryan Prenger, Rafael Valle, and Bryan Catanzaro In our recent paper, we propose WaveGlo

NVIDIA Corporation 2k Dec 26, 2022
Implementing a simplified copy of Shazam application from scratch using MinHashing and LSH.

Building Shazam from scratch In this repository we tried to implement a simplified copy of the Shazam application able to tell you the name of a song

Arturo Ghinassi 0 Nov 17, 2022
[ACL-IJCNLP 2021] "EarlyBERT: Efficient BERT Training via Early-bird Lottery Tickets"

EarlyBERT This is the official implementation for the paper in ACL-IJCNLP 2021 "EarlyBERT: Efficient BERT Training via Early-bird Lottery Tickets" by

VITA 13 May 11, 2022
PyTorch implementation of the paper: "Preference-Adaptive Meta-Learning for Cold-Start Recommendation", IJCAI, 2021.

PAML PyTorch implementation of the paper: "Preference-Adaptive Meta-Learning for Cold-Start Recommendation", IJCAI, 2021. (Continuously updating ) Int

15 Nov 18, 2022
modelvshuman is a Python library to benchmark the gap between human and machine vision

modelvshuman is a Python library to benchmark the gap between human and machine vision. Using this library, both PyTorch and TensorFlow models can be evaluated on 17 out-of-distribution datasets with

Bethge Lab 244 Jan 03, 2023
[BMVC 2021] Official PyTorch Implementation of Self-supervised learning of Image Scale and Orientation Estimation

Self-Supervised Learning of Image Scale and Orientation Estimation (BMVC 2021) This is the official implementation of the paper "Self-Supervised Learn

Jongmin Lee 17 Nov 10, 2022