Ladder Variational Autoencoders (LVAE) in PyTorch

PyTorch implementation of Ladder Variational Autoencoders (LVAE) [1]:

[Image: LVAE equation]

where the variational distributions q at each layer are multivariate Normal with diagonal covariance.

Significant differences from [1] include:

  • skip connections in the generative path: conditioning on all layers above rather than only on the layer above (see for example [2])
  • spatial (convolutional) latent variables
  • free bits [3] instead of beta annealing [4]
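As a rough illustration of the free bits idea [3] applied to diagonal Normal distributions, here is a minimal sketch. The function names `kl_diag_normal` and `free_bits_kl` are hypothetical, and grouping the KL per layer and averaging it over the batch is an assumption; the code in this repo may group or average differently.

```python
import torch


def kl_diag_normal(q_mu, q_logvar, p_mu, p_logvar):
    """Elementwise KL(q || p) between diagonal Normals, in nats."""
    return 0.5 * (
        p_logvar - q_logvar
        + (q_logvar.exp() + (q_mu - p_mu) ** 2) / p_logvar.exp()
        - 1.0
    )


def free_bits_kl(kl_per_layer, free_bits):
    """Clamp the batch-averaged KL of each layer from below (assumed grouping).

    kl_per_layer: tensor of shape (batch, num_layers), where each entry is the
    KL of one layer summed over that layer's latent dimensions.
    Returns a scalar KL term to use in the training loss.
    """
    kl_mean = kl_per_layer.mean(dim=0)  # average over the batch
    # Layers whose KL is below the floor contribute a constant, so they
    # receive no gradient pushing their KL further towards zero.
    return kl_mean.clamp(min=free_bits).sum()


kl = torch.tensor([[0.1, 2.0], [0.3, 4.0]])
print(free_bits_kl(kl, free_bits=0.5))  # tensor(3.5000)
```

In the example, the first layer has an average KL of 0.2 nats, which is raised to the floor of 0.5, while the second layer (3.0 nats) is left unchanged.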

Install requirements and run MNIST example

pip install -r requirements.txt
CUDA_VISIBLE_DEVICES=0 python main.py --zdims 32 32 32 --downsample 1 1 1 --nonlin elu --skip --blocks-per-layer 4 --gated --freebits 0.5 --learn-top-prior --data-dep-init --seed 42 --dataset static_mnist

Dependencies include boilr (a framework for PyTorch) and multiobject (which provides multi-object datasets with PyTorch dataloaders).

Likelihood results

Log likelihood bounds on the test set (average over 4 random seeds).

dataset                num layers   -ELBO          -log p(x) [100 iws]   -log p(x) [1000 iws]
binarized MNIST        3            82.14          79.47                 79.24
binarized MNIST        6            80.74          78.65                 78.52
binarized MNIST        12           80.50          78.50                 78.30
multi-dSprites (0-2)   12           26.9           23.2
SVHN                   15           4012 (1.88)    3973 (1.87)
CIFAR10                3            7651 (3.59)    7591 (3.56)
CIFAR10                6            7321 (3.44)    7268 (3.41)
CIFAR10                15           7128 (3.35)    7068 (3.32)
CelebA                 20           20026 (2.35)   19913 (2.34)

Note:

  • Bits per dimension in brackets.
  • 'iws' stands for importance weighted samples. More samples give a tighter lower bound on the log likelihood, and the bound converges to the actual log likelihood as the number of samples goes to infinity [5]. Note that the model is always trained with the ELBO (1 sample).
  • Each pixel in the images is modeled independently. The likelihood is Bernoulli for binary images, and discretized mixture of logistics with 10 components [6] otherwise.
  • One day I'll get around to evaluating the IW bound on all datasets with 10000 samples.
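The two quantities in the notes above can be sketched in a few lines. This is a dependency-free illustration rather than the evaluation code of this repo: `iw_bound` and `bits_per_dim` are hypothetical names, and the log weights are assumed to be log p(x, z_k) - log q(z_k | x) for samples z_k drawn from q.

```python
import math


def iw_bound(log_weights):
    """Importance weighted bound [5]: log( (1/K) * sum_k exp(log_w_k) ).

    With K = 1 this reduces to a single-sample ELBO estimate.
    """
    k = len(log_weights)
    m = max(log_weights)  # subtract the max for numerical stability
    return m + math.log(sum(math.exp(w - m) for w in log_weights)) - math.log(k)


def bits_per_dim(nll_nats, num_dims):
    """Convert a negative log likelihood in nats to bits per dimension."""
    return nll_nats / (num_dims * math.log(2))


# CIFAR10 with 3 layers: -ELBO of 7651 nats over 32 * 32 * 3 dimensions.
print(round(bits_per_dim(7651, 32 * 32 * 3), 2))  # 3.59
```

The printed value matches the number in brackets in the table, which suggests the bracketed values are computed as nats divided by (number of dimensions times ln 2).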

Supported datasets

  • Statically binarized MNIST [7], see Hugo Larochelle's website http://www.cs.toronto.edu/~larocheh/public/datasets/
  • SVHN
  • CIFAR10
  • CelebA rescaled and cropped to 64x64 – see code for details. The path in experiment.data.DatasetLoader has to be modified
  • binary multi-dSprites: 64x64 RGB shapes (0 to 2) in each image

Samples

Binarized MNIST

MNIST samples

Multi-dSprites

multi-dSprites samples

SVHN

SVHN samples

CIFAR

CIFAR samples

CelebA

CelebA samples

Hierarchical representations

Here we try to visualize the representations learned by individual layers. We can get a rough idea of what's going on at layer i as follows:

  • Sample latent variables from all layers above layer i (Eq. 1).

  • With these variables fixed, take S conditional samples at layer i (Eq. 2). Note that they are all conditioned on the same samples. These correspond to one row in the images below.

  • For each of these samples (each small image in the images below), pick the mode/mean of the conditional distribution of each layer below (Eq. 3).

  • Finally, sample an image x given the latent variables (Eq. 4).

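Formally:

$$
\begin{aligned}
z_{i+1}, \dots, z_L &\sim p(z_{i+1}, \dots, z_L) && (1) \\
z_i^{(s)} &\sim p(z_i \mid z_{i+1}, \dots, z_L) && (2) \\
z_j^{(s)} &= \arg\max_{z_j} \; p(z_j \mid z_{j+1}^{(s)}, \dots, z_i^{(s)}, z_{i+1}, \dots, z_L), \quad j = i-1, \dots, 1 && (3) \\
x^{(s)} &\sim p(x \mid z_1^{(s)}, \dots, z_i^{(s)}, z_{i+1}, \dots, z_L) && (4)
\end{aligned}
$$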

where s = 1, ..., S denotes the sample index.

The equations above yield S sample images conditioned on the same values of z for layers i+1 to L. These S samples are shown in one row of the images below. Notice that samples from each row are almost identical when the variability comes from a low-level layer, as such layers mostly model local structure and details. Higher layers on the other hand model global structure, and we observe more and more variability in each row as we move to higher layers. When the sampling happens in the top layer (i = L), all samples are completely independent, even within a row.
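The procedure can be sketched on a toy model as follows. Everything here is an assumption made for illustration: `ToyTopDown` and `layer_samples` are hypothetical names and not the classes used in this repo, the conditionals are unit-variance Normals that depend only on the layer directly above (no skip connections), and the last step returns the mean of p(x | z) instead of sampling x.

```python
import torch


class ToyTopDown(torch.nn.Module):
    """Hypothetical top-down path with layers 1..L, where L is the top."""

    def __init__(self, num_layers, dim):
        super().__init__()
        self.num_layers, self.dim = num_layers, dim
        # One linear map per conditional p(z_i | z_{i+1}), i = 1..L-1.
        self.maps = torch.nn.ModuleList(
            [torch.nn.Linear(dim, dim) for _ in range(num_layers - 1)]
        )
        self.decoder = torch.nn.Linear(dim, dim)

    def cond_mean(self, layer, z_above):
        if layer == self.num_layers:  # top layer: zero-mean prior
            return torch.zeros(1, self.dim)
        return self.maps[layer - 1](z_above)


@torch.no_grad()
def layer_samples(model, i, num_samples):
    """Return S outputs that differ only through the sample at layer i."""
    # Eq. 1: a single shared sample for every layer above i.
    z = None
    for layer in range(model.num_layers, i, -1):
        mean = model.cond_mean(layer, z)
        z = mean + torch.randn_like(mean)
    # Eq. 2: S conditional samples at layer i, all given the same z above.
    z = model.cond_mean(i, z) + torch.randn(num_samples, model.dim)
    # Eq. 3: below layer i, take the mode (equal to the mean for a Normal).
    for layer in range(i - 1, 0, -1):
        z = model.cond_mean(layer, z)
    # Eq. 4 (simplified): return the mean of p(x | z) rather than a sample.
    return model.decoder(z)


model = ToyTopDown(num_layers=3, dim=4)
row = layer_samples(model, i=2, num_samples=5)  # one row of 5 "images"
```

Each call corresponds to one row in the images below; calling it again draws new shared values for the layers above i and hence a new row.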

Binarized MNIST: layers 4, 8, 10, and 12 (top layer)

MNIST layers 4   MNIST layers 8

MNIST layers 10   MNIST layers 12

SVHN: layers 4, 10, 13, and 15 (top layer)

SVHN layers 4   SVHN layers 10

SVHN layers 13   SVHN layers 15

CIFAR: layers 3, 7, 10, and 15 (top layer)

CIFAR layers 3   CIFAR layers 7

CIFAR layers 10   CIFAR layers 15

CelebA: layers 6, 11, 16, and 20 (top layer)

CelebA layers 6

CelebA layers 11

CelebA layers 16

CelebA layers 20

Multi-dSprites: layers 3, 7, 10, and 12 (top layer)

multi-dSprites layers 3   multi-dSprites layers 7

multi-dSprites layers 10   multi-dSprites layers 12

Experimental details

I did not perform an extensive hyperparameter search, but the following settings worked pretty well:

  • Downsampling by a factor of 2 at the beginning of inference. After that, activations are downsampled 4 times for 64x64 images (CelebA and multi-dSprites), and 3 times otherwise. The spatial size of the final feature map is always 2x2. Between these downsampling steps there is approximately the same number of stochastic layers.
  • 4 residual blocks between stochastic layers. Haven't tried with more than 4 though, as models become quite big and we get diminishing returns.
  • The deterministic parts of bottom-up and top-down architecture are (almost) perfectly mirrored for simplicity.
  • Stochastic layers have spatial random variables, and the number of rvs per "location" (i.e. number of channels of the feature map after sampling from a layer) is 32 in all layers.
  • All other feature maps in deterministic paths have 64 channels.
  • Skip connections in the generative model (--skip).
  • Gated residual blocks (--gated).
  • Learned prior of the top layer (--learn-top-prior).
  • A form of data-dependent initialization of weights (--data-dep-init). See code for details.
  • freebits=1.0 in experiments with more than 6 stochastic layers, and 0.5 for smaller models.
  • For everything else, see _add_args() in experiment/experiment_manager.py.

With these settings, the number of parameters is roughly 1M per stochastic layer. I tried to control for this by experimenting e.g. with half the number of layers but twice the number of residual blocks, but it looks like the number of stochastic layers is what matters the most.
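The downsampling schedule from the first item in the list above can be checked with a few lines of arithmetic. This is only a sketch: `final_feature_map_size` is a hypothetical helper, and it assumes that every later downsampling step also halves the spatial size.

```python
def final_feature_map_size(image_size, later_downsamples, initial_factor=2):
    """Spatial size after the initial downsampling and the later 2x steps."""
    size = image_size // initial_factor
    for _ in range(later_downsamples):
        size //= 2  # assumed: each later step halves the resolution
    return size


print(final_feature_map_size(64, 4))  # 2 (CelebA, multi-dSprites)
print(final_feature_map_size(32, 3))  # 2 (32x32 inputs)
```

Under the same assumption, a 28x28 input would not end at 2x2, so MNIST images presumably need to be padded or resized to 32x32 first; see the code for how this is actually handled.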

References

[1] CK Sønderby, T Raiko, L Maaløe, SK Sønderby, O Winther. Ladder Variational Autoencoders, NIPS 2016

[2] L Maaløe, M Fraccaro, V Liévin, O Winther. BIVA: A Very Deep Hierarchy of Latent Variables for Generative Modeling, NeurIPS 2019

[3] DP Kingma, T Salimans, R Jozefowicz, X Chen, I Sutskever, M Welling. Improved Variational Inference with Inverse Autoregressive Flow, NIPS 2016

[4] I Higgins, L Matthey, A Pal, C Burgess, X Glorot, M Botvinick, S Mohamed, A Lerchner. beta-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework, ICLR 2017

[5] Y Burda, RB Grosse, R Salakhutdinov. Importance Weighted Autoencoders, ICLR 2016

[6] T Salimans, A Karpathy, X Chen, DP Kingma. PixelCNN++: Improving the PixelCNN with Discretized Logistic Mixture Likelihood and Other Modifications, ICLR 2017

[7] H Larochelle, I Murray. The neural autoregressive distribution estimator, AISTATS 2011
