Code for ICLR 2021 Paper, "Anytime Sampling for Autoregressive Models via Ordered Autoencoding"

Overview

Anytime Autoregressive Model

Anytime Sampling for Autoregressive Models via Ordered Autoencoding, ICLR 2021

Yilun Xu, Yang Song, Sahaj Garg, Linyuan Gong, Rui Shu, Aditya Grover, Stefano Ermon

A new family of autoregressive models that enables anytime sampling! 😃

Experiment 1: Image generation

Training:

  • Step 1: Pretrain VQ-VAE with full code length:
python vqvae.py --hidden-size latent-size --k codebook-size --dataset name-of-dataset --data-folder path-to-dataset --out-path path-to-model --pretrain

latent-size: latent code length
codebook-size: codebook size
name-of-dataset: mnist / cifar10 / celeba
path-to-dataset: path to the dataset root directory
path-to-model: path to save checkpoints
  • Step 2: Train ordered VQ-VAE:
python vqvae.py --hidden-size latent-size --k codebook-size --dataset name-of-dataset --data-folder path-to-dataset --out-path path-to-model --restore-checkpoint path-to-checkpoint --lr learning-rate

latent-size: latent code length
codebook-size: codebook size
name-of-dataset: mnist / cifar10 / celeba
path-to-dataset: path to the dataset root directory
path-to-model: path to save checkpoints
path-to-checkpoint: path to the best checkpoint from Step 1
learning-rate: learning rate (recommended: 1e-3)

  • Step 3: Train autoregressive model
python train_ar.py --task integer_sequence_modeling \
path-to-dumped-codes --vocab-size codebook-size --tokens-per-sample latent-size \
--ae-dataset name-of-dataset --ae-data-path path-to-dataset --ae-checkpoint path-to-checkpoint --ae-batch-size 512 \
--arch transformer_lm --dropout dropout-rate --attention-dropout dropout-rate --activation-dropout dropout-rate \
--optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-6 --weight-decay 0.1 --clip-norm 0.0 \
--lr 0.002 --lr-scheduler inverse_sqrt --warmup-updates 3000 --warmup-init-lr 1e-07 \
--max-sentences ar-batch-size \
--fp16 \
--max-update iterations \
--seed 2 \
--log-format json --log-interval 10000000 --no-epoch-checkpoints --no-last-checkpoints \
--save-dir path-to-model

path-to-dumped-codes: path to the dumped codes of the VQ-VAE model (speeds up training)
dropout-rate: dropout rate
latent-size: latent code length
codebook-size: codebook size
name-of-dataset: mnist / cifar10 / celeba
path-to-dataset: path to the dataset root directory
path-to-model: path to save checkpoints
path-to-checkpoint: path to the best checkpoint from Step 2
ar-batch-size: batch size of the autoregressive model
iterations: training iterations
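
The two VQ-VAE commands in Steps 1 and 2 differ only in their trailing flags, so they can be assembled from one helper. The following is a minimal sketch, not part of the repository: the function name `vqvae_command`, the sizes, and the paths (including the checkpoint file name `best.pt`) are hypothetical placeholders, and only flags that appear in the commands above are used.

```python
import shlex


def vqvae_command(latent_size, codebook_size, dataset, data_folder, out_path,
                  restore_checkpoint=None, lr=None):
    """Build the argument list for vqvae.py from the flags shown above."""
    cmd = [
        "python", "vqvae.py",
        "--hidden-size", str(latent_size),
        "--k", str(codebook_size),
        "--dataset", dataset,
        "--data-folder", data_folder,
        "--out-path", out_path,
    ]
    if restore_checkpoint is None:
        # Step 1: pretrain with the full code length.
        cmd.append("--pretrain")
    else:
        # Step 2: train the ordered VQ-VAE from the Step 1 checkpoint.
        cmd += ["--restore-checkpoint", restore_checkpoint]
        if lr is not None:
            cmd += ["--lr", str(lr)]
    return cmd


# Hypothetical values; substitute your own sizes and paths.
step1 = vqvae_command(16, 256, "mnist", "data/mnist", "ckpt/pretrain")
step2 = vqvae_command(16, 256, "mnist", "data/mnist", "ckpt/ordered",
                      restore_checkpoint="ckpt/pretrain/best.pt", lr=1e-3)

print(shlex.join(step1))
print(shlex.join(step2))
```

Each list can be passed to `subprocess.run(cmd, check=True)` to launch the corresponding step.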

Anytime sampling (Inference):

python3 generate.py --n-samples number-of-samples --out-path path-to-img \
--tokens-per-sample latent-size --vocab-size codebook-size --tokens-per-target code-num \
--ae-checkpoint path-to-ae --ae-batch-size 512 \
--ar-checkpoint path-to-ar --ar-batch-size batch-size
(--ae_celeba --ae_mnist)
number-of-samples: number of samples to be generated
path-to-img: path to the generated samples
latent-size: latent code length
codebook-size: codebook size
code-num: number of codes used for generation (Anytime sampling!)
path-to-ae: path to the VQ-VAE checkpoint in Step 2
path-to-ar: path to the Transformer checkpoint in Step 3
batch-size: batch size for the Transformer
ae_celeba: flag (store_true), set when generating CelebA
ae_mnist: flag (store_true), set when generating MNIST
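
To illustrate what `code-num` controls, here is a conceptual sketch of prefix truncation: only the first `code-num` entries of a full-length code sequence are kept. This is not the repository's implementation. The function name `truncate_codes`, the example sequence, and the use of `None` to mark unused positions are assumptions made purely for illustration; how `generate.py` and the decoder actually represent the missing codes may differ.

```python
def truncate_codes(codes, code_num, fill=None):
    """Keep the first `code_num` codes and mark the remaining positions as unused."""
    if not 0 <= code_num <= len(codes):
        raise ValueError("code_num must be between 0 and len(codes)")
    return list(codes[:code_num]) + [fill] * (len(codes) - code_num)


# Hypothetical full-length code sequence (latent-size = 8).
full = [7, 3, 9, 1, 4, 0, 2, 5]

# A smaller budget keeps a shorter prefix; the full budget keeps everything.
for budget in (2, 4, 8):
    print(budget, truncate_codes(full, budget))
```

Because the autoregressive model samples codes in order, a smaller budget means fewer sampling steps, at the cost of a coarser reconstruction.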

Experiment 2: Audio Generation

First, cd into audio-wave/src.

Training:

  • Step 1: Pretrain VQ-VAE with full code length:
python3 main.py -ex ../configuration/experimens_wave_vq_whole_bigger.json
  • Step 2: Train ordered VQ-VAE:
python3 main.py -ex ../configuration/experimens_wave_vq_whole_bigger_u.json
  • Step 3: Train the Transformer model:

    • One extra step: first dump the codebook (this step will be merged in a future version):
    python3 main.py -ex ../configuration/experimens_wave_vq_whole_bigger_u.json --dump
python train_ar.py --task integer_sequence_modeling \
path-to-dumped-codes --vocab-size codebook-size --tokens-per-sample latent-size \
--arch transformer_lm --dropout dropout-rate --attention-dropout dropout-rate --activation-dropout dropout-rate \
--optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-6 --weight-decay 0.1 --clip-norm 0.0 \
--lr 0.002 --lr-scheduler inverse_sqrt --warmup-updates 3000 --warmup-init-lr 1e-07 \
--max-sentences ar-batch-size \
--fp16 \
--max-update iterations \
--seed 2 \
--log-format json --log-interval 10000000 --no-epoch-checkpoints --no-last-checkpoints \
--save-dir path-to-model

path-to-dumped-codes: path to the dumped codes of the VQ-VAE model (speeds up training)
dropout-rate: dropout rate
latent-size: latent code length
codebook-size: codebook size
name-of-dataset: mnist / cifar10 / celeba
path-to-dataset: path to the roots of dataset
path-to-model: path to save checkpoints
ar-batch-size: batch size of the autoregressive model
iterations: training iterations

Anytime sampling (Inference):

python3 generate.py --n-samples number-of-samples --out-path path-to-img \
--tokens-per-sample latent-size --vocab-size codebook-size --tokens-per-target code-num \
--ar-checkpoint path-to-ar --ar-batch-size batch-size

number-of-samples: number of samples to be generated
path-to-img: path to the generated samples
latent-size: latent code length
codebook-size: codebook size
code-num: number of codes used for generation (Anytime sampling!)
path-to-ar: path to the Transformer checkpoint in Step 3
batch-size: batch size for the Transformer
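
Since anytime sampling trades quality for compute through `code-num`, it can be convenient to generate samples at several code budgets in one go. The sketch below only assembles the command lines; it is not part of the repository. The helper name `generate_command`, the budgets, the sizes, and the paths are hypothetical, and the check that `code-num` does not exceed `latent-size` is an assumption based on the parameter descriptions above.

```python
def generate_command(n_samples, out_path, latent_size, codebook_size,
                     code_num, ar_checkpoint, ar_batch_size):
    """Build the generate.py argument list from the flags shown above."""
    if code_num > latent_size:
        # Assumed constraint: the budget cannot exceed the full code length.
        raise ValueError("code_num must not exceed latent_size")
    return [
        "python3", "generate.py",
        "--n-samples", str(n_samples),
        "--out-path", out_path,
        "--tokens-per-sample", str(latent_size),
        "--vocab-size", str(codebook_size),
        "--tokens-per-target", str(code_num),
        "--ar-checkpoint", ar_checkpoint,
        "--ar-batch-size", str(ar_batch_size),
    ]


# Hypothetical budgets, sizes, and paths; substitute your own.
budgets = [16, 32, 64]
commands = [
    generate_command(8, f"samples/codes_{k}", 64, 256, k, "ckpt/ar/best.pt", 8)
    for k in budgets
]

for cmd in commands:
    print(" ".join(cmd))
```

Writing each budget to its own output directory keeps the results easy to compare side by side.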

Citation

@inproceedings{
xu2021anytime,
title={Anytime Sampling for Autoregressive Models via Ordered Autoencoding},
author={Yilun Xu and Yang Song and Sahaj Garg and Linyuan Gong and Rui Shu and Aditya Grover and Stefano Ermon},
booktitle={International Conference on Learning Representations},
year={2021},
url={https://openreview.net/forum?id=TSRTzJnuEBS}
}