[ICLR'21] Counterfactual Generative Networks

[Project] [PDF] [Blog] [Music Video] [Colab]

This repository contains the code for the ICLR 2021 paper "Counterfactual Generative Networks" by Axel Sauer and Andreas Geiger. If you want to take the CGN for a spin and generate counterfactual images, you can try out the Colab notebook linked above.

If you find our code or paper useful, please cite

@inproceedings{Sauer2021ICLR,
 author = {Axel Sauer and Andreas Geiger},
 title = {Counterfactual Generative Networks},
 booktitle = {International Conference on Learning Representations (ICLR)},
 year = {2021}}

Setup

Install anaconda (if you don't have it yet)

wget https://repo.anaconda.com/archive/Anaconda3-2020.11-Linux-x86_64.sh
bash Anaconda3-2020.11-Linux-x86_64.sh
source ~/.profile

Clone the repo and build the environment

git clone https://github.com/autonomousvision/counterfactual_generative_networks
cd counterfactual_generative_networks
conda env create -f environment.yml
conda activate cgn

Make all scripts executable: chmod +x scripts/*. Then, download the datasets (colored MNIST, Cue-Conflict, IN-9) and the pre-trained weights (CGN, U2-Net). Comment out the ones you don't need.

./scripts/download_data.sh
./scripts/download_weights.sh

MNISTs

The main functions of this sub-repo are:

  • Generating the MNIST variants
  • Training a CGN
  • Generating counterfactual datasets
  • Training a shape classifier

Train the CGN

We provide well-working configs and weights in mnists/experiments. To train a CGN on, e.g., Wildlife MNIST, run

python mnists/train_cgn.py --cfg mnists/experiments/cgn_wildlife_MNIST/cfg.yaml

For more info, add --help. Weights and samples will be saved in mnists/experiments/.

Generate Counterfactual Data

To generate the counterfactuals for, e.g., double-colored MNIST, run

python mnists/generate_data.py \
--weight_path mnists/experiments/cgn_double_colored_MNIST/weights/ckp.pth \
--dataset double_colored_MNIST --no_cfs 10 --dataset_size 100000

Make sure that you provide the right dataset together with the weights. You can adapt the weight-path to use your own weights. The command above generates ten counterfactuals per shape.
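To make "ten counterfactuals per shape" concrete, the sketch below keeps one shape label fixed, draws independent texture and background labels for each counterfactual, and blends the three components with the composition x = m * f + (1 - m) * b described in the paper. The helper names, the flat-list image representation, and the sampling scheme are illustrative assumptions, not the code in mnists/generate_data.py.

```python
import random

def sample_cf_labels(shape_label, no_cfs, n_classes=10, rng=None):
    """Return no_cfs (shape, texture, background) label triples for one shape."""
    rng = rng or random.Random()
    return [
        (shape_label, rng.randrange(n_classes), rng.randrange(n_classes))
        for _ in range(no_cfs)
    ]

def composite(mask, foreground, background):
    """Blend pixel-wise: mask * foreground + (1 - mask) * background."""
    return [m * f + (1.0 - m) * b for m, f, b in zip(mask, foreground, background)]

# Ten counterfactual label triples that all share shape class 3 (hypothetical values).
labels = sample_cf_labels(shape_label=3, no_cfs=10, rng=random.Random(0))

# Toy 3-pixel "image": the mask selects foreground, background, and an even mix.
image = composite([1.0, 0.0, 0.5], [1.0, 1.0, 1.0], [0.0, 0.0, 0.0])
```

Only the shape label is shared across the ten triples, which is what lets a classifier trained on such data become invariant to texture and background.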

Train the Invariant Classifier

The classifier training uses Tensor datasets, so you need to save the non-counterfactual datasets as tensors. For DATASET = {colored_MNIST, double_colored_MNIST, wildlife_MNIST}, run

python mnists/generate_data.py --dataset DATASET

To train, e.g., a shape classifier (invariant to foreground and background) on wildlife MNIST, run

python mnists/train_classifier.py --dataset wildlife_MNIST_counterfactual

Add --help for info on the available options and arguments. The hyperparameters are unchanged for all experiments.

ImageNet

The main functions of this sub-repo are:

  • Training a CGN
  • Generating data (samples, interpolations, or a whole dataset)
  • Training an invariant classifier ensemble

Train the CGN

Run

python imagenet/train_cgn.py --model_name MODEL_NAME

The default parameters should give you satisfactory results. You can change them in imagenet/config.yml. For more info, add --help. Weights and samples will be saved in imagenet/data/MODEL_NAME.

Generate Counterfactual Data

Samples. To generate a dataset of counterfactual images, run

python imagenet/generate_data.py --mode random --weights_path imagenet/weights/cgn.pth \
--n_data 100 --run_name RUN_NAME \
--truncation 0.5 --batch_sz 1

The results will be saved in imagenet/data. For more info, add --help. If you want to save only masks, textures, etc., you need to change this directly in the code (see line 206).
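The --truncation flag in the command above most likely corresponds to the truncation trick known from BigGAN, where noise is drawn from a truncated normal and then scaled, trading sample diversity for fidelity. The following is a minimal sketch of that idea using rejection sampling; the function name and the cutoff of two standard deviations are assumptions, and the repo may implement it differently.

```python
import random

def truncated_noise(dim, truncation=0.5, cutoff=2.0, rng=None):
    """Draw dim values from a standard normal clipped to [-cutoff, cutoff]
    by rejection, then scale them by truncation."""
    rng = rng or random.Random()
    values = []
    while len(values) < dim:
        z = rng.gauss(0.0, 1.0)
        if abs(z) <= cutoff:  # reject samples in the tails
            values.append(truncation * z)
    return values

noise = truncated_noise(dim=128, truncation=0.5, rng=random.Random(0))
```

Under these assumptions, a smaller truncation value concentrates the noise near zero, which usually yields cleaner but less varied images.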

The labels will be stored in a csv file. You can read them as follows:

import pandas as pd

# path: location of the labels csv written for your run
df = pd.read_csv(path, index_col=0)
df = df.set_index('im_name')  # index rows by image file name
shape_cls = df.loc['RUN_NAME_0000000.png', 'shape_cls']
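If pandas is not available, the same lookup can be sketched with the standard library. The example builds a tiny stand-in csv in memory so that it runs on its own; only the column names im_name and shape_cls come from the snippet above, and the rows and class values are made up for illustration.

```python
import csv
import io

# Stand-in for the labels file; the real csv has more rows and may have more columns.
fake_csv = io.StringIO(
    ",im_name,shape_cls\n"
    "0,RUN_NAME_0000000.png,425\n"
    "1,RUN_NAME_0000001.png,147\n"
)

# Map image file name -> shape class.
shape_by_name = {
    row["im_name"]: int(row["shape_cls"]) for row in csv.DictReader(fake_csv)
}
shape_cls = shape_by_name["RUN_NAME_0000000.png"]
```

To read a real file, replace the StringIO object with open(path, newline="").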

Generating a dataset to train a classifier. Produce one dataset with --run_name train and another with --run_name val. If you have several GPUs available, you can index the name, e.g., --run_name train_GPU_NUM. The class ImagenetCounterfactual will glob all these datasets and build a single, large training set. Make sure to set --batch_sz 1. With a larger batch size, each batch is saved as a single png; this is useful for visualization, not for training.
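The indexed naming scheme works because run directories that share a prefix can be collected with a single glob pattern. The sketch below illustrates this on a temporary directory; the directory layout, the file names, and the helper collect_images are hypothetical and do not reproduce the ImagenetCounterfactual class.

```python
import glob
import os
import tempfile

def collect_images(data_root, prefix):
    """Return all png files from run directories whose name starts with prefix."""
    pattern = os.path.join(data_root, prefix + "*", "*.png")
    return sorted(glob.glob(pattern))

# Build a throwaway layout with two per-GPU training runs and one validation run.
root = tempfile.mkdtemp()
for run in ("train_0", "train_1", "val"):
    os.makedirs(os.path.join(root, run))
    open(os.path.join(root, run, run + "_0000000.png"), "w").close()

train_files = collect_images(root, "train")  # picks up train_0 and train_1 only
```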

Interpolations. To generate interpolation sheets, e.g., from a barn (425) to whale (147), run

python imagenet/generate_data.py --mode fixed_classes \
--n_data 1 --weights_path imagenet/weights/cgn.pth --run_name barn_to_whale \
--truncation 0.3 --interp all --classes 425 425 425 --interp_cls 147 --save_noise

You can also do counterfactual interpolations, i.e., interpolating only over, e.g., shape, by setting --interp shape.
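Interpolating over a single factor can be pictured as blending only that factor's input vector while holding the others fixed. The sketch below uses plain linear interpolation on short lists; the vectors, the helper lerp, and the choice of linear blending are assumptions for illustration, and the repo's interpolation may differ.

```python
def lerp(start, end, steps):
    """Return `steps` vectors moving linearly from start to end (both included)."""
    if steps < 2:
        raise ValueError("steps must be at least 2")
    frames = []
    for i in range(steps):
        t = i / (steps - 1)
        frames.append([(1.0 - t) * s + t * e for s, e in zip(start, end)])
    return frames

# Hypothetical 2-d stand-ins for the shape inputs of two classes.
shape_a, shape_b = [0.0, 1.0], [1.0, 0.0]
# Texture and background inputs are held fixed for a shape-only sweep.
texture, background = [0.5, 0.5], [0.2, 0.8]

frames = [(s, texture, background) for s in lerp(shape_a, shape_b, steps=5)]
```

Sweeping all three inputs at once would correspond to --interp all in this picture, while the shape-only sweep mirrors --interp shape.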

Interpolation Gif. To generate a gif like the one in the teaser (sample an image of class $1, then interpolate to shape $2, then background $3, then shape $4, and finally back to $1), run

./scripts/generate_teaser_gif.sh 992 293 147 330

The positional arguments are the classes, see imagenet labels for the available options.

Train the Invariant Classifier Ensemble

Training. First, make sure that you have all datasets in imagenet/data/. Download ImageNet, e.g., from Kaggle, produce a counterfactual dataset (see above), and download the Cue-Conflict and BG-Challenge datasets (via the download script in scripts).

To train a classifier on a single GPU with a pre-trained ResNet-50 backbone, run

python imagenet/train_classifier.py -a resnet50 -b 32 --lr 0.001 -j 6 \
--epochs 45 --pretrained --cf_data CF_DATA_PATH --name RUN_NAME

Again, add --help for more information on the possible arguments.

Distributed Training. Before switching to multi-GPU training, run echo $CUDA_VISIBLE_DEVICES to check that the GPUs are visible. In the case of a single node with several GPUs, you can run, e.g.,

python imagenet/train_classifier.py -a resnet50 -b 256 --lr 0.001 -j 6 \
--epochs 45 --pretrained --cf_data CF_DATA_PATH --name RUN_NAME \
--rank 0 --multiprocessing-distributed --dist-url tcp://127.0.0.1:8890 --world-size 1

If your setup differs, e.g., several GPU machines, you need to adapt the rank and world size.
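As a rough guide for adapting these flags: the options match those of the official PyTorch ImageNet example, where, with --multiprocessing-distributed set, --world-size counts nodes and --rank is the node index, and the script derives the per-process values from them. The sketch below reproduces that arithmetic under the assumption that train_classifier.py follows the same convention; check the script before relying on it.

```python
def process_layout(n_nodes, gpus_per_node, node_rank):
    """Return (total process count, global ranks of the processes on this node),
    assuming one process per GPU."""
    world_size = n_nodes * gpus_per_node
    ranks = [node_rank * gpus_per_node + gpu for gpu in range(gpus_per_node)]
    return world_size, ranks

# One node with 4 GPUs: --world-size 1 --rank 0
single = process_layout(n_nodes=1, gpus_per_node=4, node_rank=0)

# Second of two 4-GPU nodes: --world-size 2 --rank 1
second_node = process_layout(n_nodes=2, gpus_per_node=4, node_rank=1)
```

With several machines, --dist-url would also need to point at an address on the first node that all other nodes can reach, rather than 127.0.0.1.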

Visualization. To visualize the Tensorboard outputs, run tensorboard --logdir=imagenet/runs and open the local address in your browser.

Acknowledgments

We would like to acknowledge several repos whose code, data, or models we use in parts of our implementation.
