MADE (Masked Autoencoder Density Estimation) implementation in PyTorch

Overview

pytorch-made

This code is an implementation of "Masked AutoEncoder for Density Estimation" by Germain et al., 2015. The core idea is that you can turn an auto-encoder into an autoregressive density model just by appropriately masking the connections in the MLP, ordering the input dimensions in some way and making sure that all outputs only depend on inputs earlier in the list. Like other autoregressive models (char-rnn, pixel cnns, etc), evaluating the likelihood is very cheap (a single forward pass), but sampling is linear in the number of dimensions.

figure 1

The authors of the paper also published code here, but it's a bit wordy, sprawling and in Theano. Hence my own shot at it with only ~150 lines of code and PyTorch <3.

examples

First we download the binarized mnist dataset. Then we can reproduce the first point on the plot of Figure 2 by training a 1-layer MLP of 500 units with only a single mask, and using a single fixed (but random) ordering as so:

python run.py --data-path binarized_mnist.npz -q 500

which converges at binary cross entropy loss of 94.5, as shown in the paper. We can then simultaneously train a larger model ensemble (with weight sharing in the one MLP) and average over all of the models at test time. For instance, we can use 10 orderings (-n 10) and also average over the 10 at inference time (-s 10):

python run.py --data-path binarized_mnist.npz -q 500 -n 10 -s 10

which gives a much better test loss of 79.3, but at the cost of multiple forward passes. I was not able to reproduce single-forward-pass gains that the paper alludes to when training with multiple masks, might be doing something wrong.

usage

The core class is MADE, found in made.py. It inherits from PyTorch nn.Module so you can "slot it into" larger architectures quite easily. To instantiate MADE on 1D inputs of MNIST digits for example (which have 28*28 pixels), using one hidden layer of 500 neurons, and using a single but random ordering we would do:

model = MADE(28*28, [500], 28*28, num_masks=1, natural_ordering=False)

The reason we plug the size of the output (3rd argument) into MADE is that one might want to use relatively complicated output distributions, for example a gaussian distribution would normally be parameterized by a mean and a standard deviation for each dimension, or you could bin the output range into buckets and output logprobs for a softmax, or mixture parameters, etc. In the simplest example in this code we use binary predictions, where are only parameterized by one number, hence the number of the input dimensions happens to equal the number of outputs.

License

MIT

Owner
Andrej
I like to train Deep Neural Nets on large datasets.
Andrej
Point detection through multi-instance deep heatmap regression for sutures in endoscopy

Suture detection PyTorch This repo contains the reference implementation of suture detection model in PyTorch for the paper Point detection through mu

artificial intelligence in the area of cardiovascular healthcare 3 Jul 16, 2022
Security evaluation module with onnx, pytorch, and SecML.

🚀 🐼 🔥 PandaVision Integrate and automate security evaluations with onnx, pytorch, and SecML! Installation Starting the server without Docker If you

Maura Pintor 11 Apr 12, 2022
Build and run Docker containers leveraging NVIDIA GPUs

NVIDIA Container Toolkit Introduction The NVIDIA Container Toolkit allows users to build and run GPU accelerated Docker containers. The toolkit includ

NVIDIA Corporation 15.6k Jan 01, 2023
Data augmentation for NLP, accepted at EMNLP 2021 Findings

AEDA: An Easier Data Augmentation Technique for Text Classification This is the code for the EMNLP 2021 paper AEDA: An Easier Data Augmentation Techni

Akbar Karimi 81 Dec 09, 2022
GDR-Net: Geometry-Guided Direct Regression Network for Monocular 6D Object Pose Estimation. (CVPR 2021)

GDR-Net This repo provides the PyTorch implementation of the work: Gu Wang, Fabian Manhardt, Federico Tombari, Xiangyang Ji. GDR-Net: Geometry-Guided

169 Jan 07, 2023
PyTorch implementation of the Deep SLDA method from our CVPRW-2020 paper "Lifelong Machine Learning with Deep Streaming Linear Discriminant Analysis"

Lifelong Machine Learning with Deep Streaming Linear Discriminant Analysis This is a PyTorch implementation of the Deep Streaming Linear Discriminant

Tyler Hayes 41 Dec 25, 2022
O-CNN: Octree-based Convolutional Neural Networks for 3D Shape Analysis

O-CNN This repository contains the implementation of our papers related with O-CNN. The code is released under the MIT license. O-CNN: Octree-based Co

Microsoft 607 Dec 28, 2022
Dynamic Neural Representational Decoders for High-Resolution Semantic Segmentation

Dynamic Neural Representational Decoders for High-Resolution Semantic Segmentation Requirements This repository needs mmsegmentation Training To train

20 May 28, 2022
Agile SVG maker for python

Agile SVG Maker Need to draw hundreds of frames for a GIF? Need to change the style of all pictures in a PPT? Need to draw similar images with differe

SemiWaker 4 Sep 25, 2022
[arXiv] What-If Motion Prediction for Autonomous Driving ❓🚗💨

WIMP - What If Motion Predictor Reference PyTorch Implementation for What If Motion Prediction [PDF] [Dynamic Visualizations] Setup Requirements The W

William Qi 96 Dec 29, 2022
[ICCV 2021] Released code for Causal Attention for Unbiased Visual Recognition

CaaM This repo contains the codes of training our CaaM on NICO/ImageNet9 dataset. Due to my recent limited bandwidth, this codebase is still messy, wh

Wang Tan 66 Dec 31, 2022
Machine Learning University: Accelerated Computer Vision Class

Machine Learning University: Accelerated Computer Vision Class This repository contains slides, notebooks, and datasets for the Machine Learning Unive

AWS Samples 1.3k Dec 28, 2022
Identify the emotion of multiple speakers in an Audio Segment

MevonAI - Speech Emotion Recognition Identify the emotion of multiple speakers in a Audio Segment Report Bug · Request Feature Try the Demo Here Table

Suyash More 110 Dec 03, 2022
Multi-task Multi-agent Soft Actor Critic for SMAC

Multi-task Multi-agent Soft Actor Critic for SMAC Overview The CARE formulti-task: Multi-Task Reinforcement Learning with Context-based Representation

RuanJingqing 8 Sep 30, 2022
SASM - simple crossplatform IDE for NASM, MASM, GAS and FASM assembly languages

SASM (SimpleASM) - простая кроссплатформенная среда разработки для языков ассемблера NASM, MASM, GAS, FASM с подсветкой синтаксиса и отладчиком. В SA

Dmitriy Manushin 5.6k Jan 06, 2023
Pytorch Implementation of Zero-Shot Image-to-Text Generation for Visual-Semantic Arithmetic

Pytorch Implementation of Zero-Shot Image-to-Text Generation for Visual-Semantic Arithmetic [Paper] [Colab is coming soon] Approach Example Usage To r

170 Jan 03, 2023
Genetic Algorithm, Particle Swarm Optimization, Simulated Annealing, Ant Colony Optimization Algorithm,Immune Algorithm, Artificial Fish Swarm Algorithm, Differential Evolution and TSP(Traveling salesman)

scikit-opt Swarm Intelligence in Python (Genetic Algorithm, Particle Swarm Optimization, Simulated Annealing, Ant Colony Algorithm, Immune Algorithm,A

郭飞 3.7k Jan 03, 2023
DyStyle: Dynamic Neural Network for Multi-Attribute-Conditioned Style Editing

DyStyle: Dynamic Neural Network for Multi-Attribute-Conditioned Style Editing Figure: Joint multi-attribute edits using DyStyle model. Great diversity

74 Dec 03, 2022
CTF Challenge for CSAW Finals 2021

Terminal Velocity Misc CTF Challenge for CSAW Finals 2021 This is a challenge I've had in mind for almost 15 years and never got around to building un

Jordan 6 Jul 30, 2022
OpenMMLab Semantic Segmentation Toolbox and Benchmark.

Documentation: https://mmsegmentation.readthedocs.io/ English | 简体中文 Introduction MMSegmentation is an open source semantic segmentation toolbox based

OpenMMLab 5k Dec 31, 2022