[ICLR'21] Counterfactual Generative Networks

Overview

Counterfactual Generative Networks

[Project] [PDF] [Blog] [Music Video] [Colab]

This repository contains the code for the ICLR 2021 paper "Counterfactual Generative Networks" by Axel Sauer and Andreas Geiger. If you want to take the CGN for a spin and generate counterfactual images, you can try out the Colab below.

CGN

If you find our code or paper useful, please cite

@inproceedings{Sauer2021ICLR,
 author =  {Axel Sauer, Andreas Geiger},
 title = {Counterfactual Generative Networks},
 booktitle = {International Conference on Learning Representations (ICLR)},
 year = {2021}}

Setup

Install anaconda (if you don't have it yet)

wget https://repo.anaconda.com/archive/Anaconda3-2020.11-Linux-x86_64.sh
bash Anaconda3-2020.11-Linux-x86_64.sh
source ~/.profile

Clone the repo and build the environment

git clone https://github.com/autonomousvision/counterfactual_generative_networks
cd counterfactual_generative_networks
conda env create -f environment.yml
conda activate cgn

Make all scripts executable: chmod +x scripts/*. Then, download the datasets (colored MNIST, Cue-Conflict, IN-9) and the pre-trained weights (CGN, U2-Net). Comment out the ones you don't need.

./scripts/download_data.sh
./scripts/download_weights.sh

MNISTs

The main functions of this sub-repo are:

  • Generating the MNIST variants
  • Training a CGN
  • Generating counterfactual datasets
  • Training a shape classifier

Train the CGN

We provide well-working configs and weights in mnists/experiments. To train a CGN on, e.g., Wildlife MNIST, run

python mnists/train_cgn.py --cfg mnists/experiments/cgn_wildlife_MNIST/cfg.yaml

For more info, add --help. Weights and samples will be saved in mnists/experiments/.

Generate Counterfactual Data

To generate the counterfactuals for, e.g., double-colored MNIST, run

python mnists/generate_data.py \
--weight_path mnists/experiments/cgn_double_colored_MNIST/weights/ckp.pth \
--dataset double_colored_MNIST --no_cfs 10 --dataset_size 100000

Make sure that you provide the right dataset together with the weights. You can adapt the weight-path to use your own weights. The command above generates ten counterfactuals per shape.

Train the Invariant Classifier

The classifier training uses Tensor datasets, so you need to save the non-counterfactual datasets as tensors. For DATASET = {colored_MNIST, double_colored_MNIST, wildlife_MNIST}, run

python mnists/generate_data.py --dataset DATASET

To train, e.g., a shape classifier (invariant to foreground and background) on wildlife MNIST, run,

python mnists/train_classifier.py --dataset wildlife_MNIST_counterfactual

Add --help for info on the available options and arguments. The hyperparameters are unchanged for all experiments.

ImageNet

The main functions of this sub-repo are:

  • Training a CGN
  • Generating data (samples, interpolations, or a whole dataset)
  • Training an invariant classifier ensemble

Train the CGN

Run

python imagenet/train_cgn.py --model_name MODEL_NAME

The default parameters should give you satisfactory results. You can change them in imagenet/config.yml. For more info, add --help. Weights and samples will be saved in imagenet/data/MODEL_NAME.

Generate Counterfactual Data

Samples. To generate a dataset of counterfactual images, run

python imagenet/generate_data.py --mode random --weights_path imagenet/weights/cgn.pth \
--n_data 100 --weights_path imagenet/weights/cgn.pth --run_name RUN_NAME \
--truncation 0.5 --batch_sz 1

The results will be saved in imagenet/data. For more info, add --help. If you want to save only masks, textures, etc., you need to change this directly in the code (see line 206).

The labels will be stored in a csv file. You can read them as follows:

import pandas as pd
df = pd.read_csv(path, index_col=0)
df = df.set_index('im_name')
shape_cls = df['shape_cls']['RUN_NAME_0000000.png']

Generating a dataset to train a classfier. Produce one dataset with --run_name train, the other one with --run_name val. If you have several GPUs available, you can index the name, e.g., --run_name train_GPU_NUM. The class ImagenetCounterfactual will glob all these datasets and generate a single, big training set. Make sure to set --batch_sz 1. With a larger batch size, a batch will be saved as a single png; this is useful for visualization, not for training.

Interpolations. To generate interpolation sheets, e.g., from a barn (425) to whale (147), run

python imagenet/generate_data.py --mode fixed_classes \
--n_data 1 --weights_path imagenet/weights/cgn.pth --run_name barn_to_whale \
--truncation 0.3 --interp all --classes 425 425 425 --interp_cls 147 --save_noise

You can also do counterfactual interpolations, i.e., interpolating only over, e.g., shape, by setting --interp shape.

Interpolation Gif. To generate a gif like in the teaser (Sample an image of class $1, than interpolate to shape $2, then background $3, then shape $4, and finally back to $1), run

./scripts/generate_teaser_gif.sh 992 293 147 330

The positional arguments are the classes, see imagenet labels for the available options.

Train the Invariant Classifier Ensemble

Training. First, you need to make sure that you have all datasets in imagenet/data/. Download Imagenet, e.g., from Kaggle, produce a counterfactual dataset (see above), and download the Cue-Conflict and BG-Challenge dataset (via the download script in scripts).

To train a classifier on a single GPU with a pre-trained Resnet-50 backbone, run

python imagenet/train_classifier.py -a resnet50 -b 32 --lr 0.001 -j 6 \
--epochs 45 --pretrained --cf_data CF_DATA_PATH --name RUN_NAME

Again, add --help for more information on the possible arguments.

Distributed Training. To switch to multi-GPU training, run echo $CUDA_VISIBLE_DEVICES to see if the GPUs are visible. In the case of a single node with several GPUs, you can run, e.g.,

python imagenet/train_classifier.py -a resnet50 -b 256 --lr 0.001 -j 6 \
--epochs 45 --pretrained --cf_data CF_DATA_PATH --name RUN_NAME \
--rank 0 --multiprocessing-distributed --dist-url tcp://127.0.0.1:8890 --world-size 1

If your setup differs, e.g., several GPU machines, you need to adapt the rank and world size.

Visualization. To visualize the Tensorboard outputs, run tensorboard --logdir=imagenet/runs and open the local address in your browser.

Acknowledgments

We like to acknowledge several repos of which we use parts of code, data, or models in our implementation:

FIGARO: Generating Symbolic Music with Fine-Grained Artistic Control

FIGARO: Generating Symbolic Music with Fine-Grained Artistic Control by Dimitri von Rütte, Luca Biggio, Yannic Kilcher, Thomas Hofmann FIGARO: Generat

Dimitri 83 Jan 07, 2023
Official PyTorch implementation of "Contrastive Learning from Extremely Augmented Skeleton Sequences for Self-supervised Action Recognition" in AAAI2022.

AimCLR This is an official PyTorch implementation of "Contrastive Learning from Extremely Augmented Skeleton Sequences for Self-supervised Action Reco

Gty 44 Dec 17, 2022
Multi-Output Gaussian Process Toolkit

Multi-Output Gaussian Process Toolkit Paper - API Documentation - Tutorials & Examples The Multi-Output Gaussian Process Toolkit is a Python toolkit f

GAMES 113 Nov 25, 2022
LightLog is an open source deep learning based lightweight log analysis tool for log anomaly detection.

LightLog Introduction LightLog is an open source deep learning based lightweight log analysis tool for log anomaly detection. Function description [BG

25 Dec 17, 2022
PyTorch code for our paper "Image Super-Resolution with Non-Local Sparse Attention" (CVPR2021).

Image Super-Resolution with Non-Local Sparse Attention This repository is for NLSN introduced in the following paper "Image Super-Resolution with Non-

143 Dec 28, 2022
Source code for our paper "Empathetic Response Generation with State Management"

Source code for our paper "Empathetic Response Generation with State Management" this repository is maintained by both Jun Gao and Yuhan Liu Model Ove

Yuhan Liu 3 Oct 08, 2022
Code for the ACL2021 paper "Lexicon Enhanced Chinese Sequence Labelling Using BERT Adapter"

Lexicon Enhanced Chinese Sequence Labeling Using BERT Adapter Code and checkpoints for the ACL2021 paper "Lexicon Enhanced Chinese Sequence Labelling

274 Dec 06, 2022
Detectron2 is FAIR's next-generation platform for object detection and segmentation.

Detectron2 is Facebook AI Research's next generation software system that implements state-of-the-art object detection algorithms. It is a ground-up r

Facebook Research 23.3k Jan 08, 2023
LyaNet: A Lyapunov Framework for Training Neural ODEs

LyaNet: A Lyapunov Framework for Training Neural ODEs Provide the model type--config-name to train and test models configured as those shown in the pa

Ivan Dario Jimenez Rodriguez 21 Nov 21, 2022
This is a simple framework to make object detection dataset very quickly

FastAnnotation Table of contents General info Requirements Setup General info This is a simple framework to make object detection dataset very quickly

Serena Tetart 1 Jan 24, 2022
[CVPR 2021] NormalFusion: Real-Time Acquisition of Surface Normals for High-Resolution RGB-D Scanning

NormalFusion: Real-Time Acquisition of Surface Normals for High-Resolution RGB-D Scanning Project Page | Paper | Supplemental material #1 | Supplement

KAIST VCLAB 49 Nov 24, 2022
Official implementation of "OpenPifPaf: Composite Fields for Semantic Keypoint Detection and Spatio-Temporal Association" in PyTorch.

openpifpaf Continuously tested on Linux, MacOS and Windows: New 2021 paper: OpenPifPaf: Composite Fields for Semantic Keypoint Detection and Spatio-Te

VITA lab at EPFL 50 Dec 29, 2022
Official repository of the paper Privacy-friendly Synthetic Data for the Development of Face Morphing Attack Detectors

SMDD-Synthetic-Face-Morphing-Attack-Detection-Development-dataset Official repository of the paper Privacy-friendly Synthetic Data for the Development

10 Dec 12, 2022
HINet: Half Instance Normalization Network for Image Restoration

HINet: Half Instance Normalization Network for Image Restoration Liangyu Chen, Xin Lu, Jie Zhang, Xiaojie Chu, Chengpeng Chen Paper: https://arxiv.org

303 Dec 31, 2022
(3DV 2021 Oral) Filtering by Cluster Consistency for Large-Scale Multi-Image Matching

Scalable Cluster-Consistency Statistics for Robust Multi-Object Matching (3DV 2021 Oral Presentation) Filtering by Cluster Consistency (FCC) is a very

Yunpeng Shi 11 Sep 28, 2022
Discovering Interpretable GAN Controls [NeurIPS 2020]

GANSpace: Discovering Interpretable GAN Controls Figure 1: Sequences of image edits performed using control discovered with our method, applied to thr

Erik Härkönen 1.7k Jan 03, 2023
Unsupervised Image-to-Image Translation

UNIT: UNsupervised Image-to-image Translation Networks Imaginaire Repository We have a reimplementation of the UNIT method that is more performant. It

Ming-Yu Liu 劉洺堉 1.9k Dec 26, 2022
A tutorial on training a DarkNet YOLOv4 model for the CrowdHuman dataset

YOLOv4 CrowdHuman Tutorial This is a tutorial demonstrating how to train a YOLOv4 people detector using Darknet and the CrowdHuman dataset. Table of c

JK Jung 118 Nov 10, 2022
Code for our ACL 2021 paper - ConSERT: A Contrastive Framework for Self-Supervised Sentence Representation Transfer

ConSERT Code for our ACL 2021 paper - ConSERT: A Contrastive Framework for Self-Supervised Sentence Representation Transfer Requirements torch==1.6.0

Yan Yuanmeng 478 Dec 25, 2022