Very deep VAEs in JAX/Flax

Overview

Very Deep VAEs in JAX/Flax

Implementation of the experiments in the paper Very Deep VAEs Generalize Autoregressive Models and Can Outperform Them on Images using JAX and Flax, ported from the official OpenAI PyTorch implementation.

I have tried to keep this implementation as close as possible to the original. I was able to re-use a large proportion of the code, including the data input pipeline, which still uses PyTorch. I recommend installing a CPU-only version of PyTorch for this.

Tested with JAX 0.2.10, Flax 0.3.0, PyTorch 1.7.1, NumPy 1.19.2. I also ran training to convergence on cifar10 and reproduced the test ELBO value of 2.87 from the paper, using --conv_precision=highest, see below. If anyone asks for trained checkpoints for cifar I will be happy to upload them.

From the paper, some model samples and a visualization of how it generates them:

image

Setup

As well as JAX, Flax, NumPy and PyTorch, this implementation depends on Pillow and scikit-learn:

pip install pillow
pip install sklearn

Also, you'll have to download the data, depending on which one you want to run:

./setup_cifar10.sh
./setup_imagenet.sh imagenet32
./setup_imagenet.sh imagenet64
./setup_ffhq256.sh
./setup_ffhq1024.sh  /path/to/images1024x1024  # this one depends on you first downloading the subfolder `images_1024x1024` from https://github.com/NVlabs/ffhq-dataset on your own & running `pip install torchvision`

Training models

Hyperparameters all reside in hps.py.

python train.py --hps cifar10
python train.py --hps imagenet32
python train.py --hps imagenet64
python train.py --hps ffhq256
python train.py --hps ffhq1024

TODOs

  • Implement support for 5 bit images which was used in the paper's FFHQ-256 experiments.

Known differences from the orignal

  • Instead of using the PyTorch default layer initializers we use the Flax defaults.
  • Renamed rate/distortion to kl/loglikelihood.
  • In multihost configurations, checkpoints are saved to disk on all hosts.
  • Slight changes to DMOL loss.

Things to watch out for

We tried to keep this implementation as close as possible to the author's original Pytorch implementation. There are two potentially confusing things which we chose to preserve. Firstly, the --n_batch command line argument specifies the per device batch size; on configurations with multiple GPUs/TPUs and multiple hosts this needs to be taken into account when comparing runs on different configurations. Secondly, some of the default hyperparameter settings in hps.py do not match the settings used for the paper's experiments, which are specified on page 15 of the paper.

In order to reproduce results from the paper on TPU, it may be necessary to set --conv_precision=highest, which simulates GPU-like float32 precision on the TPU. Note that this can result in slower runtime. In my experiments on cifar10 I've found that this setting has about a 1% effect on the final ELBO value and was necessary to reproduce the value 2.87 reported in the paper.

Acknowledgements

This code is very closely based on Rewon Child's implementation, thanks to him for writing that. Thanks to Julius Kunze for tidying the code and fixing some bugs.

Owner
Jamie Townsend
Jamie Townsend
The missing CMake project initializer

cmake-init - The missing CMake project initializer Opinionated CMake project initializer to generate CMake projects that are FetchContent ready, separ

1k Jan 01, 2023
The 2nd Version Of Slothybot

SlothyBot Go to this website: "https://bitly.com/SlothyBot" The 2nd Version Of Slothybot. The Bot Has Many Features, Such As: Moderation Commands; Kic

Slothy 0 Jun 01, 2022
Warning: This project does not have any current developer. See bellow.

Pylearn2: A machine learning research library Warning : This project does not have any current developer. We will continue to review pull requests and

Laboratoire d’Informatique des Systèmes Adaptatifs 2.7k Dec 26, 2022
A minimal solution to hand motion capture from a single color camera at over 100fps. Easy to use, plug to run.

Minimal Hand A minimal solution to hand motion capture from a single color camera at over 100fps. Easy to use, plug to run. This project provides the

Yuxiao Zhou 824 Jan 07, 2023
Towards Flexible Blind JPEG Artifacts Removal (FBCNN, ICCV 2021)

Towards Flexible Blind JPEG Artifacts Removal (FBCNN, ICCV 2021) Jiaxi Jiang, Kai Zhang, Radu Timofte Computer Vision Lab, ETH Zurich, Switzerland 🔥

Jiaxi Jiang 282 Jan 02, 2023
Joint project of the duo Hacker Ninjas

Project Smoothie Společný projekt dua Hacker Ninjas. První pokus o hříčku po třech týdnech učení se programování. Jakub Kolář e:\

Jakub Kolář 2 Jan 07, 2022
An Open Source Machine Learning Framework for Everyone

Documentation TensorFlow is an end-to-end open source platform for machine learning. It has a comprehensive, flexible ecosystem of tools, libraries, a

170.1k Jan 04, 2023
Differentiable Surface Triangulation

Differentiable Surface Triangulation This is our implementation of the paper Differentiable Surface Triangulation that enables optimization for any pe

61 Dec 07, 2022
OOD Dataset Curator and Benchmark for AI-aided Drug Discovery

🔥 DrugOOD 🔥 : OOD Dataset Curator and Benchmark for AI Aided Drug Discovery This is the official implementation of the DrugOOD project, this is the

108 Dec 17, 2022
This package contains deep learning models and related scripts for RoseTTAFold

RoseTTAFold This package contains deep learning models and related scripts to run RoseTTAFold This repository is the official implementation of RoseTT

1.6k Jan 03, 2023
General Vision Benchmark, a project from OpenGVLab

Introduction We build GV-B(General Vision Benchmark) on Classification, Detection, Segmentation and Depth Estimation including 26 datasets for model e

174 Dec 27, 2022
Source code for CVPR 2021 paper "Riggable 3D Face Reconstruction via In-Network Optimization"

Riggable 3D Face Reconstruction via In-Network Optimization Source code for CVPR 2021 paper "Riggable 3D Face Reconstruction via In-Network Optimizati

130 Jan 02, 2023
Codes for NeurIPS 2021 paper "Adversarial Neuron Pruning Purifies Backdoored Deep Models"

Adversarial Neuron Pruning Purifies Backdoored Deep Models Code for NeurIPS 2021 "Adversarial Neuron Pruning Purifies Backdoored Deep Models" by Dongx

Dongxian Wu 31 Dec 11, 2022
Finite-temperature variational Monte Carlo calculation of uniform electron gas using neural canonical transformation.

CoulombGas This code implements the neural canonical transformation approach to the thermodynamic properties of uniform electron gas. Building on JAX,

FermiFlow 9 Mar 03, 2022
ConE: Cone Embeddings for Multi-Hop Reasoning over Knowledge Graphs

ConE: Cone Embeddings for Multi-Hop Reasoning over Knowledge Graphs This is the code of paper ConE: Cone Embeddings for Multi-Hop Reasoning over Knowl

MIRA Lab 33 Dec 07, 2022
JAX bindings to the Flatiron Institute Non-uniform Fast Fourier Transform (FINUFFT) library

JAX bindings to FINUFFT This package provides a JAX interface to (a subset of) the Flatiron Institute Non-uniform Fast Fourier Transform (FINUFFT) lib

Dan Foreman-Mackey 32 Oct 15, 2022
Unofficial pytorch implementation of 'Image Inpainting for Irregular Holes Using Partial Convolutions'

pytorch-inpainting-with-partial-conv Official implementation is released by the authors. Note that this is an ongoing re-implementation and I cannot f

Naoto Inoue 525 Jan 01, 2023
ZEBRA: Zero Evidence Biometric Recognition Assessment

ZEBRA: Zero Evidence Biometric Recognition Assessment license: LGPLv3 - please reference our paper version: 2020-06-11 author: Andreas Nautsch (EURECO

Voice Privacy Challenge 2 Dec 12, 2021
StyleSpace Analysis: Disentangled Controls for StyleGAN Image Generation

StyleSpace Analysis: Disentangled Controls for StyleGAN Image Generation Demo video: CVPR 2021 Oral: Single Channel Manipulation: Localized or attribu

Zongze Wu 267 Dec 30, 2022
The official implementation of the CVPR2021 paper: Decoupled Dynamic Filter Networks

Decoupled Dynamic Filter Networks This repo is the official implementation of CVPR2021 paper: "Decoupled Dynamic Filter Networks". Introduction DDF is

F.S.Fire 180 Dec 30, 2022