PyTorch Implementation of Vector Quantized Variational AutoEncoders.

Overview

PyTorch implementation of VQ-VAE.

The VQ-VAE paper combines two tricks:

  1. Vector Quantization (check out this amazing blog for a better understanding).
  2. Straight-Through estimation (it solves the problem of back-propagating through discrete latent variables: the nearest-neighbour lookup that selects them is not differentiable, so the gradient is copied straight through it).
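A minimal sketch of the straight-through trick in PyTorch, assuming the quantized tensor `z_q` has the same shape as the encoder output `z_e`. In the forward pass the function returns the values of `z_q`. In the backward pass the gradient arriving at the output is copied unchanged to `z_e`. The helper name `straight_through` is hypothetical, and `torch.round` stands in for the codebook lookup only to keep the example small.

```python
import torch


def straight_through(z_e: torch.Tensor, z_q: torch.Tensor) -> torch.Tensor:
    """Forward value is z_q; gradient w.r.t. the output flows to z_e unchanged.

    (z_q - z_e).detach() is a constant for autograd, so the Jacobian of the
    output with respect to z_e is the identity.
    """
    return z_e + (z_q - z_e).detach()


# Toy check: rounding has zero gradient almost everywhere, like a lookup.
z_e = torch.tensor([0.2, 1.7, -0.6], requires_grad=True)
z_q = torch.round(z_e)  # stand-in for the nearest-neighbour codebook lookup
out = straight_through(z_e, z_q)
out.sum().backward()
# out carries the rounded values, while z_e.grad is all ones.
```

Without the trick, `z_e.grad` would be all zeros, because the gradient of `torch.round` is zero.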

[Figure: architecture]

Like a vanilla Variational AutoEncoder (VAE), this model has a neural network encoder, a decoder, and a prior. In addition, it has a latent embedding space called the codebook, of size K x D. Here, K is the number of embeddings in the codebook (the size of the discrete latent space) and D is the dimension of each embedding e.

In a vanilla variational autoencoder, the encoder output ze(x) parameterizes a Normal (Gaussian) distribution, and a latent representation z of the input x is sampled from it using the "reparameterization trick". This latent representation is then passed to the decoder.

In a VQ-VAE, ze(x) is instead used as a "key" for a nearest-neighbour lookup into the codebook, which returns zq(x), the closest embedding in the codebook. This is the Vector Quantization (VQ) operation. zq(x) is then passed to the decoder, which reconstructs the input x. The decoder can either parameterize p(x|z) as the mean of a Normal distribution using transposed convolution layers, as in a vanilla VAE, or autoregressively generate a categorical distribution over the pixel values [0, 255], as in PixelCNN. This project uses the first approach.
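The lookup can be sketched as below. This is a hypothetical module, not this repository's code. It assumes channel-first encoder outputs of shape (B, D, H, W) and quantizes each spatial position independently with k = argmin_j ||ze(x) - e_j||_2, zq(x) = e_k. The uniform initialization in [-1/K, 1/K] is also an assumption.

```python
import torch
import torch.nn as nn


class VectorQuantizer(nn.Module):
    """Hypothetical sketch: nearest-neighbour lookup into a K x D codebook."""

    def __init__(self, K: int, D: int):
        super().__init__()
        self.codebook = nn.Embedding(K, D)  # K embeddings e_j, each of dimension D
        # Assumed initialization: uniform in [-1/K, 1/K].
        nn.init.uniform_(self.codebook.weight, -1.0 / K, 1.0 / K)

    def forward(self, z_e: torch.Tensor):
        # z_e: encoder output of shape (B, D, H, W); one D-dim vector per position.
        B, D, H, W = z_e.shape
        flat = z_e.permute(0, 2, 3, 1).reshape(-1, D)  # (B*H*W, D)
        # Euclidean distance from every vector to every codebook entry: (B*H*W, K)
        dist = torch.cdist(flat, self.codebook.weight)
        indices = dist.argmin(dim=1)  # index k of the closest embedding e_k
        z_q = self.codebook(indices)  # (B*H*W, D)
        z_q = z_q.view(B, H, W, D).permute(0, 3, 1, 2)  # back to (B, D, H, W)
        return z_q, indices.view(B, H, W)


# Usage with hypothetical sizes K=512, D=64 on a 7x7 latent grid.
vq = VectorQuantizer(K=512, D=64)
z_e = torch.randn(2, 64, 7, 7)
z_q, idx = vq(z_e)
print(z_q.shape, idx.shape)  # torch.Size([2, 64, 7, 7]) torch.Size([2, 7, 7])
```

`argmin` returns integer indices, so `z_q` receives gradients only through the codebook weights and never through `z_e`. The straight-through trick is what routes the decoder's gradient back to the encoder.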

The loss function combines three components:

  1. Regular Reconstruction loss
  2. Vector Quantization loss
  3. Commitment loss
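Written out for a single input x, the three terms sum to the objective below. Here e is the codebook embedding selected for $z_e(x)$, β is the commitment weight, and sg is the stop-gradient operator described in the next paragraph. The reconstruction term is written as a negative log-likelihood. With a Gaussian decoder of fixed variance, it reduces to a squared error between x and the reconstruction, up to constants.

```latex
\mathcal{L} = -\log p\big(x \mid z_q(x)\big)
  + \big\lVert \mathrm{sg}[z_e(x)] - e \big\rVert_2^2
  + \beta \, \big\lVert z_e(x) - \mathrm{sg}[e] \big\rVert_2^2
```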

The Vector Quantization loss, $\lVert \mathrm{sg}[z_e(x)] - e \rVert_2^2$, encourages the items in the codebook to move closer to the encoder output. The commitment loss, $\lVert z_e(x) - \mathrm{sg}[e] \rVert_2^2$, encourages the output of the encoder to stay close to the embedding it picked, i.e. to commit to its codebook embedding. Here, e is the codebook embedding selected for $z_e(x)$. The commitment loss is multiplied by a constant β, which is 1.0 in this project. sg denotes the "stop-gradient" operator: it is the identity in the forward pass and has zero partial derivatives, so no gradient is propagated through the term inside it.
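A minimal sketch of the three terms in PyTorch, with `detach()` playing the role of sg. The sketch makes three assumptions. First, the function name `vqvae_loss` is hypothetical. Second, the reconstruction term is taken to be a squared error, which matches a fixed-variance Gaussian decoder. Third, `F.mse_loss` averages rather than sums, which only rescales each term. β defaults to 1.0, as in this project.

```python
import torch
import torch.nn.functional as F


def vqvae_loss(x, x_recon, z_e, z_q, beta: float = 1.0):
    """Hypothetical three-term VQ-VAE loss.

    x, x_recon: input and decoder output (same shape).
    z_e: encoder output; z_q: selected codebook embeddings (same shape as z_e).
    """
    # 1. Reconstruction: squared error (assumed Gaussian decoder).
    recon = F.mse_loss(x_recon, x)
    # 2. VQ loss ||sg[z_e(x)] - e||^2: detaching z_e means only the codebook
    #    embeddings are pulled toward the encoder output.
    vq = F.mse_loss(z_q, z_e.detach())
    # 3. Commitment loss ||z_e(x) - sg[e]||^2: detaching z_q means only the
    #    encoder output is pulled toward its chosen embedding.
    commit = F.mse_loss(z_e, z_q.detach())
    return recon + vq + beta * commit, (recon, vq, commit)


# Toy check of where each term sends its gradient.
z_e = torch.tensor([[0.5, -1.0]], requires_grad=True)
e = torch.tensor([[1.0, 0.0]], requires_grad=True)  # pretend this row was picked
x = torch.zeros(1, 4)
x_recon = torch.zeros(1, 4)  # perfect reconstruction, so recon = 0
total, (recon, vq, commit) = vqvae_loss(x, x_recon, z_e, e, beta=1.0)
total.backward()
# z_e.grad comes only from the commitment term, e.grad only from the VQ term.
```

In a full model, `x_recon` would be computed from the straight-through version of `z_q`, so the encoder would also receive the reconstruction gradient.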

Results:

The model was trained on the MNIST and CIFAR10 datasets.

Target 👉 Reconstructed Image

[Figure: target and reconstructed image pairs, and an animated GIF]

Details:

  1. Trained models for MNIST and CIFAR10 are in the Trained models directory.
  2. The hidden size of the bottleneck (z) is 128 for MNIST and 256 for CIFAR10.

Owner
Vrushank Changawala