Bayesian Generative Adversarial Networks in Tensorflow

This repository contains the Tensorflow implementation of the Bayesian GAN by Yunus Saatci and Andrew Gordon Wilson. The paper appears at NIPS 2017.

Please cite our paper if you find this code useful in your research. The bibliographic information for the paper is

@inproceedings{saatciwilson,
  title={Bayesian gan},
  author={Saatci, Yunus and Wilson, Andrew G},
  booktitle={Advances in neural information processing systems},
  pages={3622--3631},
  year={2017}
}

Contents

  1. Introduction
  2. Dependencies
  3. Training options
  4. Usage
    1. Installation
    2. Synthetic Data
    3. Examples: MNIST, CIFAR10, CelebA, SVHN
    4. Custom data

Introduction

In the Bayesian GAN we propose conditional posteriors for the generator and discriminator weights, and marginalize these posteriors through stochastic gradient Hamiltonian Monte Carlo. Key properties of the Bayesian approach to GANs include (1) accurate predictions on semi-supervised learning problems; (2) minimal intervention for good performance; (3) a probabilistic formulation for inference in response to adversarial feedback; (4) avoidance of mode collapse; and (5) a representation of multiple complementary generative and discriminative models for data, forming a probabilistic ensemble.
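
To make the sampling step more concrete, here is a minimal sketch of a generic stochastic gradient HMC update with friction, run on a toy one-dimensional standard normal target rather than on GAN weights. The function name `sghmc_sample`, the step size `eta`, the friction `alpha`, and the toy target are illustrative assumptions and are not taken from this repository.

```python
import random


def sghmc_sample(grad_log_post, theta0, n_steps, eta=0.01, alpha=0.1, seed=0):
    """Draw a chain of samples with a basic SGHMC update (illustrative only).

    grad_log_post: function returning d/dtheta of the log posterior.
    eta is the step size and alpha the friction; both are toy values.
    """
    rng = random.Random(seed)
    noise_std = (2.0 * alpha * eta) ** 0.5  # injected noise balances the friction
    theta, v = theta0, 0.0
    samples = []
    for _ in range(n_steps):
        # Momentum update: friction, gradient step, and Gaussian noise.
        v = (1.0 - alpha) * v + eta * grad_log_post(theta) + rng.gauss(0.0, noise_std)
        # Position update using the new momentum.
        theta = theta + v
        samples.append(theta)
    return samples


# Toy target: standard normal, so the gradient of the log density is -theta.
chain = sghmc_sample(lambda t: -t, theta0=3.0, n_steps=20000)
kept = chain[2000:]  # discard burn-in
mean = sum(kept) / len(kept)
var = sum((s - mean) ** 2 for s in kept) / len(kept)
print("mean=%.2f var=%.2f" % (mean, var))
```

On this toy target the retained samples should have mean close to 0 and variance close to 1. In the GAN setting the same kind of update would be applied to the generator and discriminator weights using minibatch gradients, and the retained weight samples would play the role of the ensemble described above.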

We illustrate a multimodal posterior over the parameters of the generator. Each setting of these parameters corresponds to a different generative hypothesis for the data. We show here samples generated for two different settings of this weight vector, corresponding to different writing styles. The Bayesian GAN retains this whole distribution over parameters. By contrast, a standard GAN represents this whole distribution with a point estimate (analogous to a single maximum likelihood solution), missing potentially compelling explanations for the data.

Dependencies

This code has the following dependencies (the version numbers are crucial):

  • python 2.7
  • tensorflow==1.0.0
  • scikit-learn==0.17.1

To install tensorflow 1.0.0 on linux please follow the instructions at https://www.tensorflow.org/versions/r1.0/install/.

You can install scikit-learn 0.17.1 with the following command:

pip install scikit-learn==0.17.1

Alternatively, you can create a conda environment and set it up using the provided environment.yml file, as such:

conda env create -f environment.yml -n bgan

then load the environment using

source activate bgan

Usage

Installation

  1. Install the required dependencies
  2. Clone this repository

Synthetic Data

To run the synthetic experiment from the paper you can use the bgan_synth script. For example, the following command will train the Bayesian GAN (with D=100 and d=10) for 5000 iterations and store the results in the directory given after --out.

./bgan_synth.py --x_dim 100 --z_dim 10 --numz 10 --out

To run the ML GAN for the same data run

./bgan_synth.py --x_dim 100 --z_dim 10 --numz 1 --out

bgan_synth has --save_weights, --out_dir, --z_dim, --numz, --wasserstein, --train_iter and --x_dim parameters. x_dim controls the dimensionality of the observed data (x in the paper). For a description of the other parameters please see Training options.

Once you run the above two commands you will see the output of every 100th iteration in the output directory. So, for example, the Bayesian GAN's output at the 900th iteration will look like:

In contrast, the output of the standard GAN (corresponding to numz=1, which forces ML estimation) will look like:

indicating clearly the tendency of mode collapse in the standard GAN which, for this synthetic example, is completely avoided by the Bayesian GAN.

To explore the synthetic experiment further, and to generate the Jensen-Shannon divergence plots, you can check out the notebook synth.ipynb.
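
For readers unfamiliar with the metric, the following is a minimal sketch of the Jensen-Shannon divergence between two discrete histograms. It only illustrates the definition; the helper names `kl_divergence` and `js_divergence` are hypothetical, and the notebook may compute the quantity differently.

```python
import math


def kl_divergence(p, q):
    """KL(p || q) for discrete distributions given as equal-length lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def js_divergence(p, q):
    """Jensen-Shannon divergence (natural log) between two histograms."""
    # Normalize raw counts so that each histogram sums to one.
    p_sum, q_sum = float(sum(p)), float(sum(q))
    p = [x / p_sum for x in p]
    q = [x / q_sum for x in q]
    # Mixture distribution; it is positive wherever p or q is positive.
    m = [0.5 * (pi + qi) for pi, qi in zip(p, q)]
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)


print(js_divergence([10, 20, 30], [30, 20, 10]))
```

The divergence is 0 for identical histograms and reaches log 2 for histograms with disjoint support, so lower values indicate that generated samples match the data distribution more closely.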

Unsupervised and Semi-Supervised Learning on benchmark datasets

MNIST, CIFAR10, CelebA, SVHN

The bayesian_gan_hmc script allows you to train the model on standard and custom datasets. Below we describe the usage of this script.

Data preparation

To reproduce the experiments on the MNIST, CIFAR10, CelebA and SVHN datasets you need to prepare the data and pass the correct --data_path.

  • for MNIST you don't need to prepare the data and can provide any --data_path;
  • for CIFAR10 please download and extract the python version of the data from https://www.cs.toronto.edu/~kriz/cifar.html; then use the path to the directory containing cifar-10-batches-py as --data_path;
  • for SVHN please download train_32x32.mat and test_32x32.mat files from http://ufldl.stanford.edu/housenumbers/ and use the directory containing these files as your --data_path;
  • for CelebA you will need to have OpenCV installed. You can find the download links for the data at http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html. You will need to create a celebA folder with Anno and img_align_celeba subfolders. Anno must contain list_attr_celeba.txt and img_align_celeba must contain the .jpg files. You will also need to crop the images by running the datasets/crop_faces.py script with --data_path set to the path of the folder containing celebA. When training the model, you will need to use the same path for --data_path.
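
Before starting a long run it can help to check that --data_path has the layout described above. The following sketch does that; the helper `missing_entries` and the `EXPECTED_ENTRIES` table are hypothetical and not part of this repository, and the dataset keys follow the values accepted by --dataset.

```python
import os

# Entries each dataset is expected to have under --data_path, per the list above.
EXPECTED_ENTRIES = {
    "mnist": [],  # no preparation needed
    "cifar": ["cifar-10-batches-py"],
    "svhn": ["train_32x32.mat", "test_32x32.mat"],
    "celeb": [
        os.path.join("celebA", "Anno", "list_attr_celeba.txt"),
        os.path.join("celebA", "img_align_celeba"),
    ],
}


def missing_entries(dataset, data_path):
    """Return the expected files or folders that are absent under data_path."""
    return [
        entry
        for entry in EXPECTED_ENTRIES[dataset]
        if not os.path.exists(os.path.join(data_path, entry))
    ]


# Example: report what is missing for SVHN in the current directory.
print(missing_entries("svhn", "."))
```

An empty list means every expected entry was found. The check only looks at names, so it does not verify that the files are complete or that the CelebA images have been cropped.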

Unsupervised training

You can run unsupervised learning by running the bayesian_gan_hmc script without the --semi parameter. For example, use

./run_bgan.py --data_path --dataset svhn --numz 10 --num_mcmc 2 --out_dir --train_iter 75000 --save_samples --n_save 100

to train the model on the SVHN dataset. This command will run the method for 75000 iterations and save samples every 100 iterations. Here the value given to --out_dir must lead to the directory where the results will be stored. See the data preparation section for an explanation of how to set --data_path. See the training options section for a description of the other training options.

Semi-supervised training

To run the semi-supervised experiments you can use the run_bgan_semi.py script, which offers many options including the following:

  • --out_dir: path to the folder, where the outputs will be stored
  • --n_save: samples and weights are saved every n_save iterations; default 100
  • --z_dim: dimensionality of the z vector for the generator; default 100
  • --data_path: path to the data; see data preparation for a detailed discussion; this parameter is required
  • --dataset: can be mnist, cifar, svhn or celeb; default mnist
  • --batch_size: batch size for training; default 64
  • --prior_std: std of the prior distribution over the weights; default 1
  • --num_gen: same as J in the paper; number of samples of z to integrate it out for generators; default 1
  • --num_disc: same as J_D in the paper; number of samples of z to integrate it out for discriminators; default 1
  • --num_mcmc: same as M in the paper; number of MCMC NN weight samples per z; default 1
  • --lr: learning rate used by the Adam optimizer; default 0.0002
  • --optimizer: optimization method to be used: adam (tf.train.AdamOptimizer) or sgd (tf.train.MomentumOptimizer); default adam
  • --N: number of labeled samples for semi-supervised learning
  • --train_iter: number of training iterations; default 50000
  • --save_samples: save generated samples during training
  • --save_weights: save weights during training
  • --random_seed: random seed; note that setting this seed does not lead to 100% reproducible results if GPU is used

You can also run WGANs with --wasserstein or train an ensemble of DCGANs with --ml_ensemble. In particular, you can train a DCGAN with --ml.
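
As a compact reference, the sketch below declares a subset of the options listed above with argparse, using the documented defaults. It is an illustration only: `build_parser` is a hypothetical helper, and the way the training scripts in this repository actually define their options may differ.

```python
import argparse


def build_parser():
    """Illustrative parser mirroring part of the documented options."""
    parser = argparse.ArgumentParser(description="Option sketch (not the repository's parser)")
    parser.add_argument("--data_path", required=True)  # documented as required
    parser.add_argument("--out_dir")
    parser.add_argument("--dataset", default="mnist")
    parser.add_argument("--z_dim", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--num_gen", type=int, default=1)
    parser.add_argument("--num_disc", type=int, default=1)
    parser.add_argument("--num_mcmc", type=int, default=1)
    parser.add_argument("--lr", type=float, default=0.0002)
    parser.add_argument("--optimizer", default="adam", choices=["adam", "sgd"])
    parser.add_argument("--N", type=int)  # number of labeled samples
    parser.add_argument("--train_iter", type=int, default=50000)
    parser.add_argument("--n_save", type=int, default=100)
    parser.add_argument("--save_samples", action="store_true")
    parser.add_argument("--save_weights", action="store_true")
    return parser


# Parse an example command line; "data" is a placeholder path for illustration.
args = build_parser().parse_args(["--data_path", "data", "--dataset", "cifar", "--N", "4000"])
print(args.dataset, args.N, args.lr, args.train_iter)
```

Options that are not given on the command line fall back to the defaults from the list, so the example above keeps the learning rate at 0.0002 and the number of training iterations at 50000.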

You can train the model in the semi-supervised setting by running bayesian_gan_hmc with the --semi option. Use the -N parameter to set the number of labeled examples to train on. For example, use

./run_bgan_semi.py --data_path --dataset cifar --num_gen 10 --num_mcmc 2 --out_dir --train_iter 100000 --N 4000 --lr 0.0005

to train the model on the CIFAR10 dataset with 4000 labeled examples. This command will train the model for 100000 iterations and store the outputs in the folder given to --out_dir.

To train the model on MNIST with 100 labeled examples you can use the following command.

./bayesian_gan_hmc.py --data_path / --dataset mnist --num_gen 10 --num_mcmc 2 --out_dir --train_iter 100000 -N 100 --semi --lr 0.0005

Custom data

To train the model on a custom dataset you need to define a class with a specific interface. Suppose we want to train the model on the digits dataset. This dataset consists of 8x8 images of digits. Let's suppose that the data is stored in the x_tr.npy, y_tr.npy, x_te.npy and y_te.npy files. We will assume that x_tr.npy and x_te.npy have shapes of the form (?, 8, 8, 1). We can then define the class corresponding to this dataset in bgan_util.py as follows.

# Placed in bgan_util.py, which is assumed to already import numpy as np
# and to provide one_hot_encoded.
class Digits:

    def __init__(self):
        # Load the train/test images and integer labels from disk.
        self.imgs = np.load('x_tr.npy')
        self.test_imgs = np.load('x_te.npy')
        self.labels = np.load('y_tr.npy')
        self.test_labels = np.load('y_te.npy')
        # Convert integer labels to one-hot vectors over the 10 digit classes.
        self.labels = one_hot_encoded(self.labels, 10)
        self.test_labels = one_hot_encoded(self.test_labels, 10)
        self.x_dim = [8, 8, 1]  # image shape: height, width, channels
        self.num_classes = 10

    @staticmethod
    def get_batch(batch_size, x, y):
        """Returns a random batch (sampled without replacement) from the given arrays.
        """
        idx = np.random.choice(x.shape[0], size=(batch_size,), replace=False)
        return x[idx], y[idx]

    def next_batch(self, batch_size, class_id=None):
        # class_id is not used in this example.
        return self.get_batch(batch_size, self.imgs, self.labels)

    def test_batch(self, batch_size):
        return self.get_batch(batch_size, self.test_imgs, self.test_labels)

The class must have next_batch and test_batch, and must have the imgs, labels, test_imgs, test_labels, x_dim and num_classes fields.

Now we can import the Digits class in bayesian_gan_hmc.py

from bgan_util import Digits

and add the following lines to the processing of the --dataset parameter.

if args.dataset == "digits":
    dataset = Digits()

After this preparation is done, we can train the model with, for example,

./run_bgan_semi.py --data_path --dataset digits --num_gen 10 --num_mcmc 2 --out_dir --train_iter 100000 --save_samples

Acknowledgements

We thank Pavel Izmailov and Ben Athiwaratkun for help with stress testing this code and creating the tutorial.
