A plug-and-play method for controlling an existing generative model with conditioning attributes and their compositions.

Controllable and Compositional Generation with Latent-Space Energy-Based Models

Python 3.8 · PyTorch 1.7.1 · torchdiffeq 0.2.1

Teaser image

Official PyTorch implementation of the NeurIPS 2021 paper:
Controllable and Compositional Generation with Latent-Space Energy-Based Models
Weili Nie, Arash Vahdat, Anima Anandkumar
https://nvlabs.github.io/LACE

Abstract: Controllable generation is one of the key requirements for successful adoption of deep generative models in real-world applications, but it still remains a great challenge. In particular, the compositional ability to generate novel concept combinations is out of reach for most current models. In this work, we use energy-based models (EBMs) to handle compositional generation over a set of attributes. To make them scalable to high-resolution image generation, we introduce an EBM in the latent space of a pre-trained generative model such as StyleGAN. We propose a novel EBM formulation representing the joint distribution of data and attributes together, and we show how sampling from it is formulated as solving an ordinary differential equation (ODE). Given a pre-trained generator, all we need for controllable generation is to train an attribute classifier. Sampling with ODEs is done efficiently in the latent space and is robust to hyperparameters. Thus, our method is simple, fast to train, and efficient to sample. Experimental results show that our method outperforms the state-of-the-art in both conditional sampling and sequential editing. In compositional generation, our method excels at zero-shot generation of unseen attribute combinations. Also, by composing energy functions with logical operators, this work is the first to achieve such compositionality in generating photo-realistic images of resolution 1024x1024.
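
As a rough illustration of the core idea (not the paper's exact formulation), sampling from a latent-space EBM can be sketched as integrating an ODE whose drift follows the negative energy gradient. In the sketch below, the joint energy combines a Gaussian prior term with a latent-classifier term; clf is an assumed latent attribute classifier, and all names and hyperparameters are illustrative.

# A minimal sketch, assuming a Gaussian latent prior and a hypothetical latent
# classifier `clf`; the paper's actual ODE formulation differs in detail.
import torch
import torch.nn.functional as F
from torchdiffeq import odeint  # the repo pins torchdiffeq 0.2.1

def energy(z, y, clf):
    # E(z, y) = -log p(z) - log p(y|z), with p(z) = N(0, I) up to a constant.
    prior = 0.5 * (z ** 2).sum(dim=1)
    log_p_y = F.log_softmax(clf(z), dim=1).gather(1, y[:, None]).squeeze(1)
    return prior - log_p_y

def make_drift(y, clf):
    def drift(t, z):
        # Follow -grad E(z, y), i.e., flow toward low-energy (high-probability) latents.
        with torch.enable_grad():
            z_req = z.detach().requires_grad_(True)
            grad = torch.autograd.grad(energy(z_req, y, clf).sum(), z_req)[0]
        return -grad
    return drift

# Illustrative usage: start from prior noise, integrate, then decode the final
# latent with the frozen pre-trained generator.
# z0 = torch.randn(16, 512)
# z1 = odeint(make_drift(y, clf), z0, torch.linspace(0.0, 1.0, 2))[-1]
# images = generator(z1)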

Requirements

  • Linux and Windows are supported, but we recommend Linux for performance and compatibility reasons.
  • One high-end NVIDIA GPU with at least 24 GB of memory. We have done all testing and development using a single NVIDIA V100 GPU with 32 GB of memory.
  • 64-bit Python 3.8.
  • CUDA 10.0 and Docker must be installed first.
  • Installation of the required library dependencies with Docker:
    docker build -f lace-cuda-10p0.Dockerfile --tag=lace-cuda-10-0:0.0.1 .
    docker run -it -d --gpus 0 --name lace --shm-size 8G -v $(pwd):/workspace -p 5001:6006 lace-cuda-10-0:0.0.1
    docker exec -it lace bash

Experiments on CIFAR-10

The CIFAR10 folder contains the codebase for reproducing the main results on the CIFAR-10 dataset; its scripts subfolder contains the bash scripts needed to run the code.

Data preparation

Before running the code, you have to download the data (i.e., the latent code and label pairs) from here and unzip it to the CIFAR10 folder. Alternatively, you can go to the folder CIFAR10/prepare_data and follow the instructions to generate the data yourself.
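
For intuition, the data preparation amounts to pairing StyleGAN latents with labels. A hedged sketch of that pipeline, where the model arguments are assumed stand-ins for the repo's actual networks (see CIFAR10/prepare_data for the real procedure):

import torch

def make_latent_label_pairs(mapping_net, synthesis_net, image_clf, n=64, z_dim=512):
    # Illustrative only: sample z, map it to w, render images, and pseudo-label them.
    z = torch.randn(n, z_dim)             # z from the StyleGAN prior
    w = mapping_net(z)                    # z-space -> w-space
    images = synthesis_net(w)             # render images from w
    labels = image_clf(images).argmax(1)  # labels from a pre-trained image classifier
    return w, labels                      # pairs for training the latent classifier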

Training

To train the latent classifier, you can run:

bash scripts/run_clf.sh

In the script run_clf.sh, the variable x can be set to w or z, indicating that the latent classifier is trained in the w-space or z-space of StyleGAN, respectively.
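
For reference, a latent classifier in this setting can be as small as an MLP over 512-dimensional latent vectors; the layer sizes below are illustrative assumptions, not the repo's exact architecture.

import torch.nn as nn

# Illustrative architecture for a w-space (or z-space) classifier on CIFAR-10.
latent_clf = nn.Sequential(
    nn.Linear(512, 256), nn.ReLU(),
    nn.Linear(256, 256), nn.ReLU(),
    nn.Linear(256, 10),  # logits over the 10 CIFAR-10 classes
)
# Trained with standard cross-entropy on the (latent, label) pairs prepared above.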

Sampling

To get the conditional sampling results with the ODE or Langevin dynamics (LD) sampler, you can run:

# ODE
bash scripts/run_cond_ode_sample.sh

# LD
bash scripts/run_cond_ld_sample.sh

By default, we set x to w, meaning we use the w-space classifier, because we find our method works best in the w-space. You can change the value of x to z or i to use the classifier in the z-space or pixel space instead, for comparison.
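
For contrast with the ODE sampler, a generic Langevin dynamics update is sketched below; the step size and step count are illustrative, not the repo's settings.

import torch

def langevin_sample(energy_fn, z, n_steps=100, step_size=0.01):
    # Generic LD update: z <- z - (eps/2) * grad E(z) + sqrt(eps) * noise.
    for _ in range(n_steps):
        z = z.detach().requires_grad_(True)
        grad = torch.autograd.grad(energy_fn(z).sum(), z)[0]
        z = z - 0.5 * step_size * grad + step_size ** 0.5 * torch.randn_like(z)
    return z.detach()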

To compute the conditional accuracy (ACC) and FID scores in conditional sampling with the ODE or LD sampler, you can run:

# ODE
bash scripts/run_cond_ode_score.sh

# LD
bash scripts/run_cond_ld_score.sh

Note that:

  1. For the ACC evaluation, you need a pre-trained image classifier, which can be downloaded as instructed here;

  2. For the FID evaluation, you need to have the FID reference statistics computed beforehand. You can go to the folder CIFAR10/prepare_data and follow the instructions to compute the FID reference statistics with real images sampled from CIFAR-10.
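
For reference, the FID statistics are just the mean and covariance of Inception features over real images; a minimal sketch, assuming a precomputed (N, 2048) feature matrix:

import numpy as np

def fid_reference_stats(features):
    # features: (N, 2048) Inception-v3 activations of real CIFAR-10 images.
    mu = features.mean(axis=0)
    sigma = np.cov(features, rowvar=False)
    return mu, sigma  # later compared against the statistics of generated samples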

Experiments on FFHQ

The FFHQ folder contains the codebase for reproducing the main results on the FFHQ dataset; its scripts subfolder contains the bash scripts needed to run the code.

Data preparation

Before running the code, you have to download the data (i.e., 10k pairs of latent variables and labels) from here (originally from StyleFlow) and unzip it to the FFHQ folder.

Training

To train the latent classifier, you can run:

bash scripts/run_clf.sh

Note that each att_name (e.g., glasses) in run_clf.sh corresponds to a separate attribute classifier.

Sampling

First, you have to get the pre-trained StyleGAN2 (config-f) generator by following the instructions in Convert StyleGAN2 weights from official checkpoints.

Conditional sampling

To get the conditional sampling results with the ODE or LD sampler, you can run:

# ODE
bash scripts/run_cond_ode_sample.sh

# LD
bash scripts/run_cond_ld_sample.sh

To compute the conditional accuracy (ACC) and FID scores in conditional sampling with the ODE or LD sampler, you can run:

# ODE
bash scripts/run_cond_ode_score.sh

# LD
bash scripts/run_cond_ld_score.sh

Note that:

  1. For the ACC evaluation, you need to train an FFHQ image classifier, as instructed here;

  2. For the FID evaluation, you need to have the FID reference statistics computed beforehand. You can go to the folder FFHQ/prepare_models_data and follow the instructions to compute the FID reference statistics with the StyleGAN generated FFHQ images.

Sequential editing

To get the qualitative and quantitative results of sequential editing, you can run:

# User-specified sampling
bash scripts/run_seq_edit_sample.sh

# ACC and FID
bash scripts/run_seq_edit_score.sh
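
Conceptually, sequential editing changes one attribute at a time: each step re-samples the latent with the new attribute's energy added, warm-starting from the previous latent so earlier edits are preserved. A hedged sketch of that loop, where sample_with_energy stands in for the repo's ODE/LD sampler and is an assumption:

def sequential_edit(z, attribute_energies, sample_with_energy):
    # attribute_energies: one callable E_i(z) per target attribute, in editing order.
    zs, active = [z], []
    for e in attribute_energies:
        active.append(e)
        combined = lambda z, es=tuple(active): sum(f(z) for f in es)
        z = sample_with_energy(combined, z)  # warm-start from the previous latent
        zs.append(z)
    return zs  # one latent per editing step; decode each with the generator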

Note that:

  • Similarly, you first need to train an FFHQ image classifier and compute the FID reference statistics, following the respective instructions above, to obtain the ACC and FID scores.

  • To get the face identity preservation (ID) score, you first need to download the pre-trained ArcFace network, which is publicly available here, to the folder FFHQ/pretrained/metrics.
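
For reference, the ID score is typically the cosine similarity between ArcFace embeddings of a face before and after editing; a minimal sketch, assuming a loaded ArcFace embedder:

import torch.nn.functional as F

def id_score(arcface, img_before, img_after):
    # Cosine similarity of L2-normalized face embeddings; higher = better identity preservation.
    e1 = F.normalize(arcface(img_before), dim=1)
    e2 = F.normalize(arcface(img_after), dim=1)
    return (e1 * e2).sum(dim=1).mean()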

Compositional Generation

To get the results of zero-shot generation on novel attribute combinations, you can run:

bash scripts/run_zero_shot.sh

To get the results of composing energy functions with logical operators, you can run:

bash scripts/run_combine_energy.sh
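
At the energy level, composition reduces to simple operations: summing energies acts as a conjunction (which is also what makes zero-shot attribute combinations possible), a soft minimum acts as a disjunction, and negation flips a constraint. A hedged sketch, with illustrative temperatures and weights:

import torch

def e_and(e1, e2):
    # Conjunction: low joint energy only where both attribute energies are low.
    return e1 + e2

def e_or(e1, e2, t=1.0):
    # Disjunction via a soft minimum (logsumexp of negated energies).
    return -t * torch.logsumexp(torch.stack([-e1 / t, -e2 / t]), dim=0)

def e_not(e, alpha=1.0):
    # Negation: penalize latents where the negated attribute's energy is low.
    return -alpha * e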

Experiments on MetFaces

The MetFaces folder contains the codebase for reproducing the main results on the MetFaces dataset; its scripts subfolder contains the bash scripts needed to run the code.

Data preparation

Before running the code, you have to download the data (i.e., 10k pairs of latent variables and labels) from here and unzip it to the MetFaces folder. Alternatively, you can go to the folder MetFaces/prepare_data and follow the instructions to generate the data yourself.

Training

To train the latent classifier, you can run:

bash scripts/run_clf.sh

Note that each att_name (e.g., yaw) in run_clf.sh corresponds to a separate attribute classifier.

Sampling

To get the conditional sampling and sequential editing results, you can run:

# conditional sampling
bash scripts/run_cond_sample.sh

# sequential editing
bash scripts/run_seq_edit_sample.sh

Experiments on AFHQ-Cats

The AFHQ folder contains the codebase for reproducing the main results on the AFHQ-Cats dataset; its scripts subfolder contains the bash scripts needed to run the code.

Data preparation

Before running the code, you have to download the data (i.e., 10k pairs of latent variables and labels) from here and unzip it to the AFHQ folder. Alternatively, you can go to the folder AFHQ/prepare_data and follow the instructions to generate the data yourself.

Training

To train the latent classifier, you can run:

bash scripts/run_clf.sh

Note that each att_name (e.g., breeds) in run_clf.sh corresponds to a separate attribute classifier.

Sampling

To get the conditional sampling and sequential editing results, you can run:

# conditional sampling
bash scripts/run_cond_sample.sh

# sequential editing
bash scripts/run_seq_edit_sample.sh

License

Please check the LICENSE file. This work may be used non-commercially, meaning for research or evaluation purposes only. For business inquiries, please contact [email protected].

Citation

Please cite our paper if you use this codebase:

@inproceedings{nie2021controllable,
  title={Controllable and compositional generation with latent-space energy-based models},
  author={Nie, Weili and Vahdat, Arash and Anandkumar, Anima},
  booktitle={Neural Information Processing Systems (NeurIPS)},
  year={2021}
}