PyTorch inference for "Progressive Growing of GANs" with CelebA snapshot

Overview

Progressive Growing of GANs inference in PyTorch with CelebA training snapshot

Description

This is a PyTorch inference sample ported from the original Theano/Lasagne code.

I recreated the network as described in the paper by Karras et al. Since some layers seemed to be missing from PyTorch, I implemented them as well. The network and the layers can be found in model.py.
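A likely example of such a layer is the paper's pixelwise feature vector normalization, which divides each pixel's feature vector by its root mean square: b = a / sqrt(mean_j(a_j^2) + eps), with eps = 1e-8 and the mean taken over the channels. The NumPy sketch below only illustrates that computation. It is not the PyTorch layer in model.py, and the function name pixel_norm is hypothetical.

```python
import numpy as np


def pixel_norm(x: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Normalize each pixel's feature vector to unit root mean square.

    x has shape (N, C, H, W); the mean is taken over the channel axis C.
    """
    # Mean of squared activations per pixel, kept as (N, 1, H, W) for broadcasting.
    mean_sq = np.mean(np.square(x), axis=1, keepdims=True)
    return x / np.sqrt(mean_sq + eps)


features = np.random.RandomState(0).randn(2, 4, 3, 3)
normalized = pixel_norm(features)
# The per-pixel mean square is now almost exactly 1 for every pixel.
print(np.mean(normalized ** 2, axis=1).round(4))
```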

For the demo, I used the 100-celeb-hq-1024x1024-ours snapshot, which the authors made publicly available. Since I couldn't find a model converter between Theano/Lasagne and PyTorch, I used a quick and dirty script (transfer_weights.py) to transfer the weights between the models.

This repo does not provide the code for training the networks.

Simple inference

To run the demo, simply execute predict.py. You can specify other weights with the --weights flag.

Example image:

[Example image]

Latent space interpolation

To try the latent space interpolation, use latent_interp.py. All output images will be saved in ./interp.

Use the --type argument to choose between the "gaussian interpolation" introduced in the original paper and the "slerp interpolation" introduced by Tom White in his paper Sampling Generative Networks.
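For reference, slerp moves along the great circle between two latent vectors p0 and p1 separated by the angle omega: slerp(p0, p1, t) = [sin((1 - t) * omega) * p0 + sin(t * omega) * p1] / sin(omega). The NumPy sketch below is an illustration, not the code in latent_interp.py. The helper names slerp and slerp_sequence are hypothetical, the 512-dimensional latent size is an assumption, and looping back from the last latent to the first is also an assumption, chosen because it matches the nb_latents * interp frame count given for slerp interpolation in this README.

```python
import numpy as np


def slerp(p0: np.ndarray, p1: np.ndarray, t: float) -> np.ndarray:
    """Spherical linear interpolation between latent vectors p0 and p1 at t in [0, 1]."""
    # Angle between the vectors, from the dot product of their unit versions.
    cos_omega = np.dot(p0 / np.linalg.norm(p0), p1 / np.linalg.norm(p1))
    omega = np.arccos(np.clip(cos_omega, -1.0, 1.0))
    sin_omega = np.sin(omega)
    if sin_omega < 1e-8:
        # (Anti)parallel vectors: the formula is undefined, fall back to linear interpolation.
        return (1.0 - t) * p0 + t * p1
    return (np.sin((1.0 - t) * omega) * p0 + np.sin(t * omega) * p1) / sin_omega


def slerp_sequence(latents: np.ndarray, interp: int) -> np.ndarray:
    """Walk from each latent to the next in `interp` steps, wrapping back to the first."""
    frames = []
    n = len(latents)
    for i in range(n):
        start, end = latents[i], latents[(i + 1) % n]
        for k in range(interp):
            frames.append(slerp(start, end, k / interp))
    return np.stack(frames)


rng = np.random.RandomState(0)
latents = rng.randn(4, 512)  # nb_latents = 4, assumed 512-d latent vectors
frames = slerp_sequence(latents, interp=8)
print(frames.shape)  # (32, 512)
```

Compared with straight linear interpolation, slerp keeps intermediate points at a norm typical of Gaussian samples instead of cutting through the low-norm region near the origin.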

Use --filter to change the gaussian filter size for the gaussian interpolation and --interp for the interpolation steps for the slerp interpolation.
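The gaussian interpolation can be pictured as smoothing a sequence of random latent vectors along the time axis, so that neighbouring frames get similar latents. The sketch below assumes that the filter length is the number of taps of a Gaussian window, that sigma is filter_len / 6 and that the sequence wraps around at the ends. These choices and the function name gaussian_latents are hypothetical and may differ from latent_interp.py. It produces one frame per latent vector.

```python
import numpy as np


def gaussian_latents(nb_latents: int, latent_dim: int, filter_len: int, seed: int = 0) -> np.ndarray:
    """Random latents smoothed over time with a circular Gaussian window (one frame per latent)."""
    if filter_len < 1:
        raise ValueError("filter_len must be >= 1")
    rng = np.random.RandomState(seed)
    latents = rng.randn(nb_latents, latent_dim)

    # Gaussian window with filter_len taps; sigma = filter_len / 6 is an assumption.
    shifts = np.arange(filter_len) - filter_len // 2
    sigma = filter_len / 6.0
    weights = np.exp(-0.5 * (shifts / sigma) ** 2)
    weights /= weights.sum()

    # Circular convolution along the time axis, so the last frame blends into the first.
    smoothed = np.zeros_like(latents)
    for shift, w in zip(shifts, weights):
        smoothed += w * np.roll(latents, shift, axis=0)

    # Averaging shrinks the variance; rescale to unit RMS so latents keep their usual scale.
    return smoothed / np.sqrt(np.mean(np.square(smoothed)))


frames = gaussian_latents(nb_latents=60, latent_dim=512, filter_len=9)
print(frames.shape)  # (60, 512): one frame per latent vector
```

A longer filter averages over more neighbours, which gives slower and smoother transitions between faces.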

The following arguments are defined:

  • --weights - path to pretrained PyTorch state dict
  • --output - directory for storing interpolated images
  • --batch_size - batch size for DataLoader
  • --num_workers - number of workers for DataLoader
  • --type {gauss, slerp} - interpolation type
  • --nb_latents - number of latent vectors to generate
  • --filter - gaussian filter length for interpolating latent space (gauss interpolation)
  • --interp - interpolation length between each latent vector (slerp interpolation)
  • --seed - random seed for numpy and PyTorch
  • --cuda - use GPU
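A parser accepting the arguments listed above could look like the following argparse sketch. The argument types are inferred from the descriptions, and defaults are left out because they are not documented here, so this is an illustration rather than the parser in latent_interp.py.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Parser mirroring the documented latent_interp.py flags (types assumed, no defaults)."""
    parser = argparse.ArgumentParser(description="Latent space interpolation (sketch)")
    parser.add_argument("--weights", type=str, help="path to pretrained PyTorch state dict")
    parser.add_argument("--output", type=str, help="directory for storing interpolated images")
    parser.add_argument("--batch_size", type=int, help="batch size for DataLoader")
    parser.add_argument("--num_workers", type=int, help="number of workers for DataLoader")
    parser.add_argument("--type", choices=["gauss", "slerp"], help="interpolation type")
    parser.add_argument("--nb_latents", type=int, help="number of latent vectors to generate")
    parser.add_argument("--filter", type=int, help="gaussian filter length (gauss interpolation)")
    parser.add_argument("--interp", type=int, help="steps between latent vectors (slerp interpolation)")
    parser.add_argument("--seed", type=int, help="random seed for numpy and PyTorch")
    parser.add_argument("--cuda", action="store_true", help="use GPU")
    return parser


args = build_parser().parse_args(["--type", "slerp", "--nb_latents", "10", "--interp", "5", "--cuda"])
print(args.type, args.nb_latents * args.interp)  # slerp 50
```

Assuming the flags behave as described, a call such as python latent_interp.py --type slerp --nb_latents 10 --interp 5 --cuda would request 10 * 5 = 50 slerp frames on the GPU.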

The total number of generated frames depends on the interpolation technique used.

For gaussian interpolation, the number of generated frames equals nb_latents, while slerp interpolation generates nb_latents * interp frames.

Example interpolation:

[Example interpolation]

Live latent space interpolation

A live demo of the latent space interpolation using PyGame is provided in pygame_interp_demo.py.

Use the --size argument to change the output window size.

The following arguments are defined:

  • --weights - path to pretrained PyTorch state dict
  • --num_workers - number of workers for DataLoader
  • --type {gauss, slerp} - interpolation type
  • --nb_latents - number of latent vectors to generate
  • --filter - gaussian filter length for interpolating latent space (gauss interpolation)
  • --interp - interpolation length between each latent vector (slerp interpolation)
  • --size - PyGame window size
  • --seed - random seed for numpy and PyTorch
  • --cuda - use GPU
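To show a generated image in a PyGame window, the network output has to become an 8-bit array indexed as [x][y], which is the layout pygame.surfarray.make_surface expects. The sketch below assumes the generator emits (C, H, W) values in [-1, 1]; that range and the helper name to_display_array are assumptions, not taken from pygame_interp_demo.py.

```python
import numpy as np


def to_display_array(img_chw: np.ndarray) -> np.ndarray:
    """Turn one (C, H, W) generator output in [-1, 1] into a uint8 (W, H, C) array."""
    # Map [-1, 1] to [0, 255]; clip values the generator may push slightly out of range.
    img = np.clip((img_chw + 1.0) * 127.5, 0.0, 255.0).astype(np.uint8)
    # pygame.surfarray indexes arrays as [x][y], so move width to the first axis.
    return np.ascontiguousarray(img.transpose(2, 1, 0))


fake_output = np.random.RandomState(0).uniform(-1.2, 1.2, size=(3, 64, 48))  # stand-in for G(z)
frame = to_display_array(fake_output)
print(frame.shape, frame.dtype)  # (48, 64, 3) uint8
```

The surface created from such an array could then be resized to the --size window with pygame.transform.scale.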

Transferring weights

The pretrained Lasagne weights can be transferred to a PyTorch state dict using transfer_weights.py.

To transfer snapshots from the paper other than CelebA, you have to modify the model architecture accordingly and use the corresponding weights.
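Without a converter, a transfer essentially copies arrays in order while fixing layout differences. The sketch below assumes both networks list their parameters in the same order and that every 2-D source array is a Lasagne DenseLayer weight, which Lasagne stores as (inputs, units) while nn.Linear expects (out_features, in_features). Lasagne convolutions flip their kernels by default (flip_filters=True) whereas PyTorch computes a cross-correlation, so depending on how the original layers were configured, 4-D kernels may also need a spatial flip; this sketch does not do that. The helper transfer_params is hypothetical and is not transfer_weights.py.

```python
from collections import OrderedDict

import numpy as np


def transfer_params(src_params, target_shapes):
    """Map an ordered list of Lasagne-style arrays onto ordered PyTorch parameter names.

    src_params: list of numpy arrays, in network order.
    target_shapes: OrderedDict of name -> expected PyTorch shape, in the same order.
    """
    if len(src_params) != len(target_shapes):
        raise ValueError(f"parameter count mismatch: {len(src_params)} vs {len(target_shapes)}")
    state = OrderedDict()
    for arr, (name, shape) in zip(src_params, target_shapes.items()):
        arr = np.asarray(arr)
        if arr.ndim == 2:
            # Assumed DenseLayer weight: Lasagne (inputs, units) -> nn.Linear (out, in).
            arr = arr.T
        if arr.shape != tuple(shape):
            raise ValueError(f"{name}: got {arr.shape}, expected {tuple(shape)}")
        # Store an independent C-contiguous copy.
        state[name] = np.array(arr, order="C")
    return state


src = [np.ones((8, 3, 3, 3)), np.zeros(8), np.arange(8 * 4).reshape(8, 4), np.zeros(4)]
targets = OrderedDict([
    ("conv.weight", (8, 3, 3, 3)),
    ("conv.bias", (8,)),
    ("fc.weight", (4, 8)),
    ("fc.bias", (4,)),
])
state = transfer_params(src, targets)
print({name: arr.shape for name, arr in state.items()})
```

In practice the source list could come from lasagne.layers.get_all_param_values(network) and the target shapes from the PyTorch model's state_dict(); the resulting arrays can then be wrapped with torch.from_numpy and passed to load_state_dict. Checking every shape catches most ordering mistakes early, although square dense layers make a careful look at the transpose rule worthwhile.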

Environment

The code was tested on Ubuntu 16.04 with an NVIDIA GTX 1080 using PyTorch v.0.2.0_4.

  • transfer_weights.py needs Theano and Lasagne to load the pretrained weights.
  • pygame_interp_demo.py needs PyGame to visualize the output.

On this setup, a single forward pass took approximately 0.031 seconds.
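The README does not say how this figure was measured. A per-pass timing is most meaningful with warm-up runs and, on a GPU, explicit synchronization, because CUDA kernels are launched asynchronously. The helper below is a hypothetical sketch of one way to measure it; when timing a PyTorch model on the GPU, torch.cuda.synchronize can be passed as sync.

```python
import time


def time_forward(fn, n_warmup=3, n_runs=20, sync=None):
    """Return the average wall-clock seconds per call of fn() after warm-up runs."""
    for _ in range(n_warmup):
        fn()  # warm-up: lets one-time costs (allocation, lazy initialization) pass
    if sync is not None:
        sync()  # wait for pending asynchronous work before starting the clock
    start = time.perf_counter()
    for _ in range(n_runs):
        fn()
    if sync is not None:
        sync()  # make sure every timed call has actually finished
    return (time.perf_counter() - start) / n_runs


calls = []
avg = time_forward(lambda: calls.append(sum(range(10_000))))
print(f"{avg * 1000:.3f} ms per call")
```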


License

This code is a modified version of the original code, which is released under the CC BY-NC license with the following copyright notice:

# Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

In accordance with Section 3, I hereby identify Tero Karras et al. and NVIDIA as the original authors of the material.

Owner
Deep Learning Frameworks @NVIDIA