The Official PyTorch Implementation of "VAEBM: A Symbiosis between Variational Autoencoders and Energy-based Models" (ICLR 2021 spotlight paper)

Overview

Zhisheng Xiao, Karsten Kreis, Jan Kautz, Arash Vahdat


VAEBM trains an energy network to refine the data distribution learned by an NVAE, where the energy network and the VAE jointly define an energy-based model. The NVAE is pretrained before the energy network is trained; please refer to the NVAE implementation for more details about constructing and training NVAE.
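To make "jointly define" concrete: as we read the paper, the VAEBM density is proportional to p_θ(x, z) e^{-E_ψ(x)}, and sampling is carried out in the NVAE's reparameterized noise space ε. There, with T denoting the decoder that maps noise to images and a standard-normal prior on ε, the energy becomes E_ψ(T(ε)) + ½‖ε‖². The minimal sketch below shows only that bookkeeping; `decoder` and `energy_net` are hypothetical toy stand-ins (a fixed affine map and a quadratic), not the NVAE or this repository's energy network.

```python
import math


# Hypothetical toy stand-ins for the trained networks. In VAEBM these roles are
# played by the NVAE decoder (noise -> image) and the learned energy network.
def decoder(eps):
    """Toy generator T(eps): a fixed affine map from 2-D noise to a 2-D 'image'."""
    return [2.0 * eps[0] + 1.0, 0.5 * eps[1] - 1.0]


def energy_net(x):
    """Toy energy E_psi(x): lowest at the point (1, -1)."""
    return 0.5 * ((x[0] - 1.0) ** 2 + (x[1] + 1.0) ** 2)


def vaebm_energy(eps):
    """Energy in noise space: E(eps) = E_psi(T(eps)) + 0.5 * ||eps||^2.

    The second term is the negative log of the standard-normal prior on eps
    (up to a constant), so the VAE prior and the EBM correction are combined.
    """
    prior_term = 0.5 * sum(e * e for e in eps)
    return energy_net(decoder(eps)) + prior_term


def unnormalized_density(eps):
    """h(eps) is proportional to N(eps; 0, I) * exp(-E_psi(T(eps)))."""
    return math.exp(-vaebm_energy(eps))
```

Points in ε that the decoder maps to low-energy images and that are likely under the prior receive the highest unnormalized density.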

Set up datasets

We trained on several datasets, including CIFAR-10, CelebA 64, LSUN Church 64 and CelebA HQ 256. For the large datasets, we store the data in LMDB format for I/O efficiency. Check here for information regarding dataset preparation.

Training NVAE

We use the following commands on each dataset for training the NVAE backbone. To train NVAEs, please use the original NVAE codebase with the commands given here.

CIFAR-10 (8x 16-GB GPUs)

python train.py --data $DATA_DIR/cifar10 --root $CHECKPOINT_DIR --save $EXPR_ID --dataset cifar10 \
      --num_channels_enc 128 --num_channels_dec 128 --epochs 400 --num_postprocess_cells 2 --num_preprocess_cells 2 \
      --num_latent_scales 1 --num_latent_per_group 20 --num_cell_per_cond_enc 2 --num_cell_per_cond_dec 2 \
      --num_preprocess_blocks 1 --num_postprocess_blocks 1 --num_groups_per_scale 30 --batch_size 32 \
      --weight_decay_norm 1e-1 --num_nf 1 --num_mixture_dec 1 --fast_adamax  --arch_instance res_mbconv \
      --num_process_per_node 8 --use_se --res_dist

CelebA-64 (8x 16-GB GPUs)

python train.py --data  $DATA_DIR/celeba64_lmdb --root $CHECKPOINT_DIR --save $EXPR_ID --dataset celeba_64 \
      --num_channels_enc 48 --num_channels_dec 48 --epochs 50 --num_postprocess_cells 2 --num_preprocess_cells 2 \
      --num_latent_scales 3 --num_latent_per_group 20 --num_cell_per_cond_enc 2 --num_cell_per_cond_dec 2 \
      --num_preprocess_blocks 1 --num_postprocess_blocks 1 --weight_decay_norm 1e-1 --num_groups_per_scale 5 \
      --batch_size 32 --num_nf 1 --num_mixture_dec 1 --fast_adamax  --warmup_epochs 1 --arch_instance res_mbconv \
      --num_process_per_node 8 --use_se --res_dist

CelebA-HQ-256 (8x 32-GB GPUs)

python train.py --data  $DATA_DIR/celeba/celeba-lmdb --root $CHECKPOINT_DIR --save $EXPR_ID --dataset celeba_256 \
      --num_channels_enc 32 --num_channels_dec 32 --epochs 200 --num_postprocess_cells 2 --num_preprocess_cells 2 \
      --num_latent_per_group 20 --num_cell_per_cond_enc 2 --num_cell_per_cond_dec 2 --num_preprocess_blocks 1 \
      --num_postprocess_blocks 1 --weight_decay_norm 1e-2 --num_x_bits 5 --num_latent_scales 5 --num_groups_per_scale 4 \
      --num_nf 2 --batch_size 8 --fast_adamax  --num_mixture_dec 1 \
      --weight_decay_norm_anneal  --weight_decay_norm_init 1e1 --learning_rate 6e-3 --arch_instance res_mbconv \
      --num_process_per_node 8 --use_se --res_dist

LSUN Churches Outdoor 64 (8x 16-GB GPUs)

python train.py --data $DATA_DIR/LSUN/ --root $CHECKPOINT_DIR --save $EXPR_ID --dataset lsun_church_64 \
      --num_channels_enc 48 --num_channels_dec 48 --epochs 60 --num_postprocess_cells 2 --num_preprocess_cells 2 \
      --num_latent_scales 3 --num_latent_per_group 20 --num_cell_per_cond_enc 2 --num_cell_per_cond_dec 2 \
      --num_preprocess_blocks 1 --num_postprocess_blocks 1 --weight_decay_norm 1e-1 --num_groups_per_scale 5 \
      --batch_size 32 --num_nf 1 --num_mixture_dec 1 --fast_adamax  --warmup_epochs 1 --arch_instance res_mbconv \
      --num_process_per_node 8 --use_se --res_dist

Training VAEBM

We use the following commands on each dataset for training VAEBM. Note that you need to train the NVAE on the corresponding dataset before running the training command here. After training the NVAE, pass the path of its checkpoint to the --checkpoint argument.

Note that the training of VAEBM will eventually explode (see Appendix E of our paper), so it is important to save checkpoints regularly. After the training explodes, stop the run and use the last few saved checkpoints for testing.
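Because training eventually diverges, one simple pattern is to save on a fixed schedule, keep only the most recent few files, and stop once the loss looks non-finite or very large. The sketch below is a hypothetical helper: `RollingCheckpointer`, `has_exploded`, the file-name pattern and the 1e4 threshold are our assumptions, not part of this repository. It uses `pickle` so that it runs without PyTorch; in PyTorch code you would normally call `torch.save` on the model's `state_dict()` instead.

```python
import math
import os
import pickle
from collections import deque


class RollingCheckpointer:
    """Hypothetical helper: save every `save_every` iterations, keep the newest `keep` files."""

    def __init__(self, save_dir, save_every=500, keep=5):
        self.save_dir = save_dir
        self.save_every = save_every
        self.keep = keep
        self.saved = deque()  # paths currently on disk, oldest first
        os.makedirs(save_dir, exist_ok=True)

    def maybe_save(self, iteration, state):
        """Pickle `state` on schedule; delete the oldest file once more than `keep` exist."""
        if iteration % self.save_every != 0:
            return None
        path = os.path.join(self.save_dir, f"EBM_{iteration:06d}.pkl")
        with open(path, "wb") as f:
            pickle.dump(state, f)
        self.saved.append(path)
        while len(self.saved) > self.keep:
            os.remove(self.saved.popleft())
        return path


def has_exploded(loss, threshold=1e4):
    """Assumed heuristic: a NaN/inf or very large loss signals that training diverged."""
    return not math.isfinite(loss) or abs(loss) > threshold
```

In a training loop you would call `maybe_save(it, state)` every iteration and break out when `has_exploded(loss)` returns true, then evaluate the few files left in `save_dir`. A fixed loss threshold is only a rough signal; checking sample quality of the retained checkpoints remains the deciding test.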

CIFAR-10

We train VAEBM on CIFAR-10 using one 32-GB V100 GPU.

python train_VAEBM.py  --checkpoint ./checkpoints/cifar10/checkpoint.pt --experiment cifar10_exp1 \
      --dataset cifar10 --im_size 32 --data ./data/cifar10 --num_steps 10 \
      --wd 3e-5 --step_size 8e-5 --total_iter 30000 --alpha_s 0.2 --lr 4e-5 --max_p 0.6 \
      --anneal_step 5000. --batch_size 32 --n_channel 128

CelebA 64

We train VAEBM on CelebA 64 using one 32-GB V100 GPU.

python train_VAEBM.py --checkpoint ./checkpoints/celeba_64/checkpoint.pt --experiment celeba64_exp1 --dataset celeba_64 \
      --im_size 64 --lr 5e-5 --batch_size 32 --n_channel 64 --num_steps 10 --use_mu_cd --wd 3e-5 --step_size 5e-6 --total_iter 30000 \
      --alpha_s 0.2

LSUN Church 64

We train VAEBM on LSUN Church 64 using one 32-GB V100 GPU.

python train_VAEBM.py --checkpoint ./checkpoints/lsun_church/checkpoint.pt --experiment lsunchurch_exp1 --dataset lsun_church \
      --im_size 64 --batch_size 32 --n_channel 64 --num_steps 10 --use_mu_cd --wd 3e-5 --step_size 4e-6 --total_iter 30000 --alpha_s 0.2 --lr 4e-5 \
      --use_buffer --max_p 0.6 --anneal_step 5000
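The LSUN and CelebA HQ commands pass --use_buffer, --max_p and --anneal_step, whose exact meaning is not documented here. One common design these names suggest, which we assume below without having checked train_VAEBM.py, is a replay buffer of past Langevin samples: each chain starts from a buffered sample with probability p, and p ramps linearly from 0 to max_p over the first anneal_step iterations. The plain-Python sketch implements only that assumed behavior; `ReplayBuffer`, `reuse_probability` and `init_chains` are hypothetical names.

```python
import random


class ReplayBuffer:
    """Hypothetical fixed-capacity FIFO buffer of past Langevin samples."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.items = []

    def push(self, samples):
        self.items.extend(samples)
        # Drop the oldest samples once the buffer exceeds its capacity.
        overflow = len(self.items) - self.capacity
        if overflow > 0:
            del self.items[:overflow]

    def __len__(self):
        return len(self.items)


def reuse_probability(iteration, max_p, anneal_step):
    """Assumed schedule: ramp linearly from 0 to max_p over anneal_step iterations."""
    return max_p * min(1.0, iteration / anneal_step)


def init_chains(batch_size, buffer, iteration, max_p, anneal_step, fresh_sample, rng=random):
    """Start each chain from the buffer with probability p, otherwise from a fresh sample."""
    p = reuse_probability(iteration, max_p, anneal_step)
    chains = []
    for _ in range(batch_size):
        if len(buffer) > 0 and rng.random() < p:
            chains.append(rng.choice(buffer.items))
        else:
            chains.append(fresh_sample())  # e.g. new noise drawn from the VAE prior
    return chains
```

Ramping p up gradually means early training relies on fresh samples, while later chains can continue from earlier, already-refined samples.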

CelebA HQ 256

We train VAEBM on CelebA HQ 256 using four 32-GB V100 GPUs.

python train_VAEBM_distributed.py --checkpoint ./checkpoints/celeba_256/checkpoint.pt --experiment celeba256_exp1 --dataset celeba_256 \
      --num_process_per_node 4 --im_size 256 --batch_size 4 --n_channel 64 --num_steps 6 --use_mu_cd --wd 3e-5 --step_size 3e-6 \
      --total_iter 9000 --alpha_s 0.3 --lr 4e-5 --use_buffer --max_p 0.6 --anneal_step 3000 --buffer_size 2000

Sampling from VAEBM

To generate samples from VAEBM after training, run sample_VAEBM.py; it will generate 50000 test images in the given path. When sampling, we typically run Langevin dynamics for more steps than during training to obtain better sample quality; see Appendix E of the paper for the step sizes and number of steps we use to obtain test samples for each dataset. The other arguments, which are needed to load the VAE and energy network correctly, are the same as in the training commands.
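The --num_steps and --step_size arguments set the length and step size of the Langevin chain. A minimal sketch of the textbook (unadjusted) Langevin update, ε ← ε − (η/2)∇E(ε) + √η·ω with ω ~ N(0, I), is shown below. It assumes this standard form; the exact scaling used in sample_VAEBM.py may differ. `grad_energy` is a hypothetical callable; in VAEBM the gradient would come from autograd through the energy network and the decoder.

```python
import math
import random


def langevin(eps0, grad_energy, num_steps, step_size, rng=random):
    """Unadjusted Langevin dynamics on a list of floats.

    Each step applies eps <- eps - (step_size / 2) * grad_E(eps) + sqrt(step_size) * noise,
    where noise is drawn from a standard normal independently per coordinate.
    """
    eps = list(eps0)
    noise_scale = math.sqrt(step_size)
    for _ in range(num_steps):
        grad = grad_energy(eps)
        eps = [
            e - 0.5 * step_size * g + noise_scale * rng.gauss(0.0, 1.0)
            for e, g in zip(eps, grad)
        ]
    return eps
```

For example, `langevin([0.0, 0.0], lambda e: [e[0] - 1.0, e[1] + 1.0], num_steps=16, step_size=8e-5)` runs 16 steps on a toy quadratic energy centred at (1, -1). Running more steps lets the chain travel further toward low-energy regions, which matches the motivation for the longer chains used at test time.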

For example, the script used to sample CIFAR-10 is

python sample_VAEBM.py --checkpoint ./checkpoints/cifar10/checkpoint.pt --ebm_checkpoint ./saved_models/cifar10/cifar10_exp1/EBM.pth \
      --dataset cifar10 --im_size 32 --batch_size 40 --n_channel 128 --num_steps 16 --step_size 8e-5

For CelebA 64,

python sample_VAEBM.py --checkpoint ./checkpoints/celeba_64/checkpoint.pt --ebm_checkpoint ./saved_models/celeba_64/celeba64_exp1/EBM.pth \
      --dataset celeba_64 --im_size 64 --batch_size 40 --n_channel 64 --num_steps 20 --step_size 5e-6

For LSUN Church 64,

python sample_VAEBM.py --checkpoint ./checkpoints/lsun_church/checkpoint.pt --ebm_checkpoint ./saved_models/lsun_church/lsunchurch_exp1/EBM.pth \
      --dataset lsun_church --im_size 64 --batch_size 40 --n_channel 64 --num_steps 20 --step_size 4e-6

For CelebA HQ 256,

python sample_VAEBM.py --checkpoint ./checkpoints/celeba_256/checkpoint.pt --ebm_checkpoint ./saved_models/celeba_256/celeba256_exp1/EBM.pth \
      --dataset celeba_256 --im_size 256 --batch_size 10 --n_channel 64 --num_steps 24 --step_size 3e-6

Evaluation

After sampling, use the TensorFlow or PyTorch implementation to compute the FID scores. For example, when using the TensorFlow implementation, you can obtain the FID score by saving the training images in /path/to/training_images and running the script:

python fid.py /path/to/training_images /path/to/sampled_images

For CIFAR-10, the training statistics can be downloaded from here, and the FID score can be computed by running

python fid.py /path/to/sampled_images /path/to/precalculated_stats.npz
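For reference, FID fits a Gaussian to the Inception features of the real and the generated images and compares the two; as we understand the TensorFlow implementation's format, the precalculated .npz file stores the mean and covariance of the training-set features as arrays named mu and sigma. With real statistics (μ_r, Σ_r) and generated statistics (μ_g, Σ_g), the score is:

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2} \right)
```

Lower is better: the score is zero only when both the means and the covariances coincide.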

For the Inception Score, save the samples in a single NumPy array with pixel values in the range [0, 255] and run

python ./thirdparty/inception_score.py --sample_dir /path/to/sampled_images

where the code for computing Inception Score is adapted from here.
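For reference, the Inception Score is computed from an Inception network's class probabilities p(y|x) for each sample: IS = exp(E_x KL(p(y|x) ‖ p(y))). The plain-Python sketch below assumes those probabilities have already been computed (it does not run an Inception network) and shows only the final formula, together with the common convention of reporting the mean and standard deviation over 10 splits. It is an illustration, not the adapted script in ./thirdparty.

```python
import math


def inception_score(probs):
    """Inception Score from per-image class probabilities.

    IS = exp(mean_i KL(p(y|x_i) || p(y))), where p(y) is the mean of p(y|x_i) over images.
    Terms with p(y|x_i) = 0 contribute nothing to the KL divergence and are skipped.
    """
    n = len(probs)
    num_classes = len(probs[0])
    marginal = [sum(p[k] for p in probs) / n for k in range(num_classes)]
    total_kl = 0.0
    for p in probs:
        total_kl += sum(
            pk * (math.log(pk) - math.log(mk))
            for pk, mk in zip(p, marginal)
            if pk > 0
        )
    return math.exp(total_kl / n)


def inception_score_splits(probs, splits=10):
    """Mean and (population) standard deviation of IS over equal-size splits."""
    size = len(probs) // splits
    scores = [inception_score(probs[i * size:(i + 1) * size]) for i in range(splits)]
    mean = sum(scores) / splits
    std = math.sqrt(sum((s - mean) ** 2 for s in scores) / splits)
    return mean, std
```

The score is high when each sample is classified confidently (low-entropy p(y|x)) while the samples as a whole cover many classes (high-entropy p(y)); confident predictions spread evenly over K classes give IS = K.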

License

Please check the LICENSE file. VAEBM may be used non-commercially, meaning for research or evaluation purposes only. For business inquiries, please contact [email protected].

BibTeX

Cite our paper using the following BibTeX entry:

@inproceedings{xiao2021vaebm,
title={VAEBM: A Symbiosis between Variational Autoencoders and Energy-based Models},
author={Zhisheng Xiao and Karsten Kreis and Jan Kautz and Arash Vahdat},
booktitle={International Conference on Learning Representations},
year={2021}
}