Code for the paper "Adversarial Generator-Encoder Networks"

Overview

This repository contains code for the paper

"Adversarial Generator-Encoder Networks" (AAAI'18) by Dmitry Ulyanov, Andrea Vedaldi, Victor Lempitsky.

Pretrained models

This is how you can access the models used to generate figures in the paper.

  1. First install the dev version of PyTorch 0.2 and make sure you have Jupyter Notebook ready.

  2. Then download the models with the script:

bash download_pretrained.sh
  3. Run jupyter notebook and go through evaluate.ipynb.

Here is an example of samples and reconstructions for the ImageNet, CelebA, and CIFAR-10 datasets, generated with evaluate.ipynb.
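If you prefer not to go through the notebook, the sketch below illustrates how samples could be drawn from a downloaded generator checkpoint. It is a hypothetical example, not code from evaluate.ipynb: the checkpoint path and the assumption that a full generator module was pickled with torch.save are mine, and it uses current PyTorch syntax rather than the 0.2 API.

import torch

# Hypothetical sketch; evaluate.ipynb is the authoritative reference.
netG = torch.load('pretrained_models/celeba_netG.pth')  # assumed path and save format
netG.eval()

nz = 64                                                 # latent dimensionality used for the CelebA model
z = torch.randn(64, nz, 1, 1)
z = z / z.view(64, -1).norm(dim=1).view(64, 1, 1, 1)    # project onto the unit sphere (--noise sphere)
with torch.no_grad():
    samples = netG(z)                                   # a batch of generated images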

CelebA

Samples and reconstructions (figures)

CIFAR-10

Samples and reconstructions (figures)

Tiny ImageNet

Samples and reconstructions (figures)

Training

Use the age.py script to train a model. The most important parameters are:

  • --dataset: one of [celeba, cifar10, imagenet, svhn, mnist]
  • --dataroot: for datasets included in torchvision, this is the directory everything will be downloaded to; for the imagenet and celeba datasets, it is the path to a directory containing train and val folders.
  • --image_size: spatial size the images are resized to (32 or 64 in the examples below)
  • --save_dir: path to a folder, where checkpoints will be stored
  • --nz: dimensionality of latent space
  • --batch_size: batch size. Default: 64.
  • --netG: .py file with the generator definition; searched for in the models directory
  • --netE: .py file with the encoder definition; searched for in the models directory
  • --netG_chp: path to a generator checkpoint to load from
  • --netE_chp: path to an encoder checkpoint to load from
  • --nepoch: number of epochs to run
  • --start_epoch: epoch number to start from. Useful for finetuning.
  • --e_updates: update plan for the encoder: <num steps>;KL_fake:<weight>,KL_real:<weight>,match_z:<weight>,match_x:<weight>
  • --g_updates: update plan for the generator: <num steps>;KL_fake:<weight>,match_z:<weight>,match_x:<weight>
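To make the update-plan format concrete, here is a small illustrative parser (not the repository's actual code) showing how a plan string decomposes into a step count and per-loss weights:

def parse_update_plan(plan):
    # '1;KL_fake:1,KL_real:1,match_z:0,match_x:10'
    #   -> (1, {'KL_fake': 1.0, 'KL_real': 1.0, 'match_z': 0.0, 'match_x': 10.0})
    steps, terms = plan.split(';')
    weights = dict((name, float(w)) for name, w in
                   (term.split(':') for term in terms.split(',')))
    return int(steps), weights

num_steps, weights = parse_update_plan('1;KL_fake:1,KL_real:1,match_z:0,match_x:10')
# The network is updated num_steps times per iteration, with each loss term
# scaled by its weight.

For instance, in the CelebA command further below the encoder takes one step per iteration with the data-space matching term weighted by 10, while the generator takes three steps with the latent matching term weighted by 1000.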

And misc arguments:

  • --workers: number of dataloader workers.
  • --ngf: controls the number of channels in the generator
  • --ndf: controls the number of channels in the encoder
  • --beta1: beta1 parameter for the Adam optimizer
  • --cpu: do not use GPU
  • --criterion: parametric (param) or non-parametric (nonparam) way to compute the KL term; the parametric criterion fits a Gaussian to the data, the non-parametric one is based on nearest neighbors (see the sketch after this list). Default: param.
  • --KL: which KL to compute: qp or pq. Default: qp.
  • --noise: latent noise distribution: sphere (uniform on the unit sphere) or gaussian. Default: sphere.
  • --match_z: loss to use for reconstruction in latent space: L1|L2|cos. Default: cos.
  • --match_x: loss to use for reconstruction in data space: L1|L2|cos. Default: L1.
  • --drop_lr: the learning rate is dropped every drop_lr epochs.
  • --save_every: controls how often intermediate results are stored. Default: 50.
  • --manual_seed: random seed. Default 123.
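For intuition on the default choices, here is a minimal sketch, assuming the parametric (param) criterion fits a diagonal Gaussian to a batch of latent codes and evaluates a closed-form KL against a unit Gaussian prior, and that sphere noise is obtained by normalizing Gaussian samples. This is an illustration of the idea only, not the repository's implementation, which may differ in details:

import torch
import torch.nn.functional as F

def kl_qp_parametric(z):
    # KL( N(mu, diag(var)) || N(0, I) ), with mu and var estimated per dimension
    # from a batch of latent codes z of shape (batch, nz).
    mu = z.mean(dim=0)
    var = z.var(dim=0)
    return 0.5 * (var + mu ** 2 - 1.0 - torch.log(var)).sum()

# Noise drawn uniformly on the unit sphere (the default --noise sphere):
z = torch.randn(256, 128)
z = z / z.norm(dim=1, keepdim=True)

# Cosine reconstruction loss in latent space (the default --match_z cos):
z_rec = z + 0.1 * torch.randn_like(z)                   # stand-in for E(G(z))
cos_loss = 1.0 - F.cosine_similarity(z, z_rec, dim=1).mean()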

Here are commands you can start with:

CelebA

Let <data_root> be a directory with two folders, train and val, each containing the images for the corresponding split.

python age.py --dataset celeba --dataroot <data_root> --image_size 64 --save_dir <save_dir> --lr 0.0002 --nz 64 --batch_size 64 --netG dcgan64px --netE dcgan64px --nepoch 5 --drop_lr 5 --e_updates '1;KL_fake:1,KL_real:1,match_z:0,match_x:10' --g_updates '3;KL_fake:1,match_z:1000,match_x:0'

It is then beneficial to finetune the model with a larger batch size and a stronger matching weight:

python age.py --dataset celeba --dataroot <data_root> --image_size 64 --save_dir <save_dir> --start_epoch 5 --lr 0.0002 --nz 64 --batch_size 256 --netG dcgan64px --netE dcgan64px --nepoch 6 --drop_lr 5   --e_updates '1;KL_fake:1,KL_real:1,match_z:0,match_x:15' --g_updates '3;KL_fake:1,match_z:1000,match_x:0' --netE_chp  <save_dir>/netE_epoch_5.pth --netG_chp <save_dir>/netG_epoch_5.pth

ImageNet

python age.py --dataset imagenet --dataroot /path/to/imagenet_dir/ --save_dir <save_dir> --image_size 32 --lr 0.0002 --nz 128 --netG dcgan32px --netE dcgan32px --nepoch 6 --drop_lr 3 --e_updates '1;KL_fake:1,KL_real:1,match_z:0,match_x:10' --g_updates '2;KL_fake:1,match_z:2000,match_x:0' --workers 12

It can be beneficial to switch to a batch size of 256 after several epochs.

CIFAR-10

python age.py --dataset cifar10 --image_size 32 --save_dir <save_dir> --lr 0.0002 --nz 128 --netG dcgan32px --netE dcgan32px --nepoch 150 --drop_lr 40  --e_updates '1;KL_fake:1,KL_real:1,match_z:0,match_x:10' --g_updates '2;KL_fake:1,match_z:1000,match_x:0'

Tested with Python 2.7.

The implementation is based on the PyTorch DCGAN example code.

Citation

If you find this code useful, please cite our paper:

@inproceedings{DBLP:conf/aaai/UlyanovVL18,
  author    = {Dmitry Ulyanov and
               Andrea Vedaldi and
               Victor S. Lempitsky},
  title     = {It Takes (Only) Two: Adversarial Generator-Encoder Networks},
  booktitle = {{AAAI}},
  publisher = {{AAAI} Press},
  year      = {2018}
}