Official Pytorch implementation of "Learning Debiased Representation via Disentangled Feature Augmentation (Neurips 2021, Oral)"

Overview

Learning Debiased Representation via Disentangled Feature Augmentation (Neurips 2021, Oral): Official Project Webpage

This repository provides the official PyTorch implementation of the following paper:

Learning Debiased Representation via Disentangled Feature Augmentation
Jungsoo Lee* (KAIST AI, Kakao Enterprise), Eungyeup Kim* (KAIST AI, Kakao Enterprise),
Juyoung Lee (Kakao Enterprise), Jihyeon Lee (KAIST AI), and Jaegul Choo (KAIST AI)
(* indicates equal contribution. The order of first authors was chosen by tossing a coin.)
NeurIPS 2021, Oral

Paper: Arxiv

Abstract: Image classification models tend to make decisions based on peripheral attributes of data items that have strong correlation with a target variable (i.e., dataset bias). These biased models suffer from the poor generalization capability when evaluated on unbiased datasets. Existing approaches for debiasing often identify and emphasize those samples with no such correlation (i.e., bias-conflicting) without defining the bias type in advance. However, such bias-conflicting samples are significantly scarce in biased datasets, limiting the debiasing capability of these approaches. This paper first presents an empirical analysis revealing that training with "diverse" bias-conflicting samples beyond a given training set is crucial for debiasing as well as the generalization capability. Based on this observation, we propose a novel feature-level data augmentation technique in order to synthesize diverse bias-conflicting samples. To this end, our method learns the disentangled representation of (1) the intrinsic attributes (i.e., those inherently defining a certain class) and (2) bias attributes (i.e., peripheral attributes causing the bias), from a large number of bias-aligned samples, the bias attributes of which have strong correlation with the target variable. Using the disentangled representation, we synthesize bias-conflicting samples that contain the diverse intrinsic attributes of bias-aligned samples by swapping their latent features. By utilizing these diversified bias-conflicting features during the training, our approach achieves superior classification accuracy and debiasing results against the existing baselines on both synthetic as well as a real-world dataset.

Code Contributors

Jungsoo Lee [Website] [LinkedIn] [Google Scholar] (KAIST AI, Kakao Enterprise)
Eungyeup Kim [Website] [LinkedIn] [Google Scholar] (KAIST AI, Kakao Enterprise)
Juyoung Lee [Website] (Kakao Enterprise)

Pytorch Implementation

Installation

Clone this repository.

git clone https://github.com/kakaoenterprise/Learning-Debiased-Disentangled.git
cd Learning-Debiased-Disentangled
pip install -r requirements.txt

Datasets

We used three datasets in our paper.

Download the datasets with the following url. Note that BFFHQ is the dataset used in "BiaSwap: Removing Dataset Bias with Bias-Tailored Swapping Augmentation" (Kim et al., ICCV 2021). Unzip the files and the directory structures will be as following:

cmnist
 └ 0.5pct / 1pct / 2pct / 5pct
     └ align
     └ conlict
     └ valid
 └ test
cifar10c
 └ 0.5pct / 1pct / 2pct / 5pct
     └ align
     └ conlict
     └ valid
 └ test
bffhq
 └ 0.5pct
 └ valid
 └ test

How to Run

CMNIST

Vanilla
python train.py --dataset cmnist --exp=cmnist_0.5_vanilla --lr=0.01 --percent=0.5pct --train_vanilla --tensorboard --wandb
python train.py --dataset cmnist --exp=cmnist_1_vanilla --lr=0.01 --percent=1pct --train_vanilla --tensorboard --wandb
python train.py --dataset cmnist --exp=cmnist_2_vanilla --lr=0.01 --percent=2pct --train_vanilla --tensorboard --wandb
python train.py --dataset cmnist --exp=cmnist_5_vanilla --lr=0.01 --percent=5pct --train_vanilla --tensorboard --wandb
bash scripts/run_cmnist_vanilla.sh
Ours
python train.py --dataset cmnist --exp=cmnist_0.5_ours --lr=0.01 --percent=0.5pct --curr_step=10000 --lambda_swap=1 --lambda_dis_align=10 --lambda_swap_align=10 --use_lr_decay --lr_decay_step=10000 --lr_gamma=0.5 --train_ours --tensorboard --wandb
python train.py --dataset cmnist --exp=cmnist_1_ours --lr=0.01 --percent=1pct  --curr_step=10000 --lambda_swap=1 --lambda_dis_align=10 --lambda_swap_align=10 --use_lr_decay --lr_decay_step=10000 --lr_gamma=0.5 --train_ours --tensorboard --wandb
python train.py --dataset cmnist --exp=cmnist_2_ours --lr=0.01 --percent=2pct  --curr_step=10000 --lambda_swap=1 --lambda_dis_align=10 --lambda_swap_align=10 --use_lr_decay --lr_decay_step=10000 --lr_gamma=0.5 --train_ours --tensorboard --wandb
python train.py --dataset cmnist --exp=cmnist_5_ours --lr=0.01 --percent=5pct  --curr_step=10000 --lambda_swap=1 --lambda_dis_align=10 --lambda_swap_align=10 --use_lr_decay --lr_decay_step=10000 --lr_gamma=0.5 --train_ours --tensorboard --wandb
bash scripts/run_cmnist_ours.sh

Corrupted CIFAR10

Vanilla
python train.py --dataset cifar10c --exp=cifar10c_0.5_vanilla --lr=0.001 --percent=0.5pct --train_vanilla --tensorboard --wandb
python train.py --dataset cifar10c --exp=cifar10c_1_vanilla --lr=0.001 --percent=1pct --train_vanilla --tensorboard --wandb
python train.py --dataset cifar10c --exp=cifar10c_2_vanilla --lr=0.001 --percent=2pct --train_vanilla --tensorboard --wandb
python train.py --dataset cifar10c --exp=cifar10c_5_vanilla --lr=0.001 --percent=5pct --train_vanilla --tensorboard --wandb
bash scripts/run_cifar10c_vanilla.sh
Ours
python train.py --dataset cifar10c --exp=cifar10c_0.5_ours --lr=0.0005 --percent=0.5pct --curr_step=10000 --lambda_swap=1 --lambda_dis_align=1 --lambda_swap_align=1 --use_lr_decay --lr_decay_step=10000 --lr_gamma=0.5 --train_ours --tensorboard --wandb
python train.py --dataset cifar10c --exp=cifar10c_1_ours --lr=0.001 --percent=1pct --curr_step=10000 --lambda_swap=1 --lambda_dis_align=5 --lambda_swap_align=5 --use_lr_decay --lr_decay_step=10000 --lr_gamma=0.5 --train_ours --tensorboard --wandb
python train.py --dataset cifar10c --exp=cifar10c_2_ours --lr=0.001 --percent=2pct --curr_step=10000 --lambda_swap=1 --lambda_dis_align=5 --lambda_swap_align=5 --use_lr_decay --lr_decay_step=10000 --lr_gamma=0.5 --train_ours --tensorboard --wandb
python train.py --dataset cifar10c --exp=cifar10c_5_ours --lr=0.001 --percent=5pct --curr_step=10000 --lambda_swap=1 --lambda_dis_align=1 --lambda_swap_align=1 --use_lr_decay --lr_decay_step=10000 --lr_gamma=0.5 --train_ours --tensorboard --wandb
bash scripts/run_cifar10c_ours.sh

BFFHQ

Vanilla
python train.py --dataset bffhq --exp=bffhq_0.5_vanilla --lr=0.0001 --percent=0.5pct --train_vanilla --tensorboard --wandb
bash scripts/run_bffhq_vanilla.sh
Ours
python train.py --dataset bffhq --exp=bffhq_0.5_ours --lr=0.0001 --percent=0.5pct --lambda_swap=0.1 --curr_step=10000 --use_lr_decay --lr_decay_step=10000 --lambda_dis_align 2. --lambda_swap_align 2. --dataset bffhq --train_ours --tensorboard --wandb
bash scripts/run_bffhq_ours.sh

Pretrained Models

In order to test our pretrained models, run the following command.

python test.py --pretrained_path=
   
     --dataset=
    
      --percent=
     

     
    
   

We provide the pretrained models in the following urls.
CMNIST 0.5pct
CMNIST 1pct
CMNIST 2pct
CMNIST 5pct

CIFAR10C 0.5pct
CIFAR10C 1pct
CIFAR10C 2pct
CIFAR10C 5pct

BFFHQ 0.5pct

Citations

Bibtex coming soon!

Contact

Jungsoo Lee

Eungyeup Kim

Juyoung Lee

Kakao Enterprise/Vision Team

Acknowledgments

This work was mainly done when both of the first authors were doing internship at Vision Team/AI Lab/Kakao Enterprise. Our pytorch implementation is based on LfF. Thanks for the implementation.

Owner
Kakao Enterprise Corp.
Kakao Enterprise Corp.
[ICCV 2021 Oral] Just Ask: Learning to Answer Questions from Millions of Narrated Videos

Just Ask: Learning to Answer Questions from Millions of Narrated Videos Webpage • Demo • Paper This repository provides the code for our paper, includ

Antoine Yang 87 Jan 05, 2023
Multiple style transfer via variational autoencoder

ST-VAE Multiple style transfer via variational autoencoder By Zhi-Song Liu, Vicky Kalogeiton and Marie-Paule Cani This repo only provides simple testi

13 Oct 29, 2022
Audio2Face - Audio To Face With Python

Audio2Face Discription We create a project that transforms audio to blendshape w

FACEGOOD 724 Dec 26, 2022
Speech Separation Using an Asynchronous Fully Recurrent Convolutional Neural Network

Speech Separation Using an Asynchronous Fully Recurrent Convolutional Neural Network This repository is the official implementation of Speech Separati

Kai Li (李凯) 116 Nov 09, 2022
PyTorch implementation for our paper Learning Character-Agnostic Motion for Motion Retargeting in 2D, SIGGRAPH 2019

Learning Character-Agnostic Motion for Motion Retargeting in 2D We provide PyTorch implementation for our paper Learning Character-Agnostic Motion for

Rundi Wu 367 Dec 22, 2022
Source code for GNN-LSPE (Graph Neural Networks with Learnable Structural and Positional Representations)

Graph Neural Networks with Learnable Structural and Positional Representations Source code for the paper "Graph Neural Networks with Learnable Structu

Vijay Prakash Dwivedi 180 Dec 22, 2022
This repository contains the code for our paper VDA (public in EMNLP2021 main conference)

Virtual Data Augmentation: A Robust and General Framework for Fine-tuning Pre-trained Models This repository contains the code for our paper VDA (publ

RUCAIBox 13 Aug 06, 2022
This repository contains project created during the Data Challenge module at London School of Hygiene & Tropical Medicine

LSHTM_RCS This repository contains project created during the Data Challenge module at London School of Hygiene & Tropical Medicine (LSHTM) in collabo

Lukas Kopecky 3 Jan 30, 2022
Python library containing BART query generation and BERT-based Siamese models for neural retrieval.

Neural Retrieval Embedding-based Zero-shot Retrieval through Query Generation leverages query synthesis over large corpuses of unlabeled text (such as

Amazon Web Services - Labs 35 Apr 14, 2022
Recall Loss for Semantic Segmentation (This repo implements the paper: Recall Loss for Semantic Segmentation)

Recall Loss for Semantic Segmentation (This repo implements the paper: Recall Loss for Semantic Segmentation) Download Synthia dataset The model uses

32 Sep 21, 2022
Prototype python implementation of the ome-ngff table spec

Prototype python implementation of the ome-ngff table spec

Kevin Yamauchi 8 Nov 20, 2022
JugLab 33 Dec 30, 2022
Spline is a tool that is capable of running locally as well as part of well known pipelines like Jenkins (Jenkinsfile), Travis CI (.travis.yml) or similar ones.

Welcome to spline - the pipeline tool Important note: Since change in my job I didn't had the chance to continue on this project. My main new project

Thomas Lehmann 29 Aug 22, 2022
Doods2 - API for detecting objects in images and video streams using Tensorflow

DOODS2 - Return of DOODS Dedicated Open Object Detection Service - Yes, it's a b

Zach 101 Jan 04, 2023
Colossal-AI: A Unified Deep Learning System for Large-Scale Parallel Training

ColossalAI An integrated large-scale model training system with efficient parallelization techniques. arXiv: Colossal-AI: A Unified Deep Learning Syst

HPC-AI Tech 7.9k Jan 08, 2023
The lightweight PyTorch wrapper for high-performance AI research. Scale your models, not the boilerplate.

The lightweight PyTorch wrapper for high-performance AI research. Scale your models, not the boilerplate. Website • Key Features • How To Use • Docs •

Pytorch Lightning 21.1k Jan 01, 2023
ALIbaba's Collection of Encoder-decoders from MinD (Machine IntelligeNce of Damo) Lab

AliceMind AliceMind: ALIbaba's Collection of Encoder-decoders from MinD (Machine IntelligeNce of Damo) Lab This repository provides pre-trained encode

Alibaba 1.4k Jan 01, 2023
GLODISMO: Gradient-Based Learning of Discrete Structured Measurement Operators for Signal Recovery

GLODISMO: Gradient-Based Learning of Discrete Structured Measurement Operators for Signal Recovery This is the code to the paper: Gradient-Based Learn

3 Feb 15, 2022
DeepI2I: Enabling Deep Hierarchical Image-to-Image Translation by Transferring from GANs

DeepI2I: Enabling Deep Hierarchical Image-to-Image Translation by Transferring from GANs Abstract: Image-to-image translation has recently achieved re

yaxingwang 23 Apr 14, 2022
RINDNet: Edge Detection for Discontinuity in Reflectance, Illumination, Normal and Depth, in ICCV 2021 (oral)

RINDNet RINDNet: Edge Detection for Discontinuity in Reflectance, Illumination, Normal and Depth Mengyang Pu, Yaping Huang, Qingji Guan and Haibin Lin

Mengyang Pu 75 Dec 15, 2022