Official Pytorch implementation of "Learning Debiased Representation via Disentangled Feature Augmentation (Neurips 2021, Oral)"

Overview

Learning Debiased Representation via Disentangled Feature Augmentation (Neurips 2021, Oral): Official Project Webpage

This repository provides the official PyTorch implementation of the following paper:

Learning Debiased Representation via Disentangled Feature Augmentation
Jungsoo Lee* (KAIST AI, Kakao Enterprise), Eungyeup Kim* (KAIST AI, Kakao Enterprise),
Juyoung Lee (Kakao Enterprise), Jihyeon Lee (KAIST AI), and Jaegul Choo (KAIST AI)
(* indicates equal contribution. The order of first authors was chosen by tossing a coin.)
NeurIPS 2021, Oral

Paper: Arxiv

Abstract: Image classification models tend to make decisions based on peripheral attributes of data items that have strong correlation with a target variable (i.e., dataset bias). These biased models suffer from the poor generalization capability when evaluated on unbiased datasets. Existing approaches for debiasing often identify and emphasize those samples with no such correlation (i.e., bias-conflicting) without defining the bias type in advance. However, such bias-conflicting samples are significantly scarce in biased datasets, limiting the debiasing capability of these approaches. This paper first presents an empirical analysis revealing that training with "diverse" bias-conflicting samples beyond a given training set is crucial for debiasing as well as the generalization capability. Based on this observation, we propose a novel feature-level data augmentation technique in order to synthesize diverse bias-conflicting samples. To this end, our method learns the disentangled representation of (1) the intrinsic attributes (i.e., those inherently defining a certain class) and (2) bias attributes (i.e., peripheral attributes causing the bias), from a large number of bias-aligned samples, the bias attributes of which have strong correlation with the target variable. Using the disentangled representation, we synthesize bias-conflicting samples that contain the diverse intrinsic attributes of bias-aligned samples by swapping their latent features. By utilizing these diversified bias-conflicting features during the training, our approach achieves superior classification accuracy and debiasing results against the existing baselines on both synthetic as well as a real-world dataset.

Code Contributors

Jungsoo Lee [Website] [LinkedIn] [Google Scholar] (KAIST AI, Kakao Enterprise)
Eungyeup Kim [Website] [LinkedIn] [Google Scholar] (KAIST AI, Kakao Enterprise)
Juyoung Lee [Website] (Kakao Enterprise)

Pytorch Implementation

Installation

Clone this repository.

git clone https://github.com/kakaoenterprise/Learning-Debiased-Disentangled.git
cd Learning-Debiased-Disentangled
pip install -r requirements.txt

Datasets

We used three datasets in our paper.

Download the datasets with the following url. Note that BFFHQ is the dataset used in "BiaSwap: Removing Dataset Bias with Bias-Tailored Swapping Augmentation" (Kim et al., ICCV 2021). Unzip the files and the directory structures will be as following:

cmnist
 └ 0.5pct / 1pct / 2pct / 5pct
     └ align
     └ conlict
     └ valid
 └ test
cifar10c
 └ 0.5pct / 1pct / 2pct / 5pct
     └ align
     └ conlict
     └ valid
 └ test
bffhq
 └ 0.5pct
 └ valid
 └ test

How to Run

CMNIST

Vanilla
python train.py --dataset cmnist --exp=cmnist_0.5_vanilla --lr=0.01 --percent=0.5pct --train_vanilla --tensorboard --wandb
python train.py --dataset cmnist --exp=cmnist_1_vanilla --lr=0.01 --percent=1pct --train_vanilla --tensorboard --wandb
python train.py --dataset cmnist --exp=cmnist_2_vanilla --lr=0.01 --percent=2pct --train_vanilla --tensorboard --wandb
python train.py --dataset cmnist --exp=cmnist_5_vanilla --lr=0.01 --percent=5pct --train_vanilla --tensorboard --wandb
bash scripts/run_cmnist_vanilla.sh
Ours
python train.py --dataset cmnist --exp=cmnist_0.5_ours --lr=0.01 --percent=0.5pct --curr_step=10000 --lambda_swap=1 --lambda_dis_align=10 --lambda_swap_align=10 --use_lr_decay --lr_decay_step=10000 --lr_gamma=0.5 --train_ours --tensorboard --wandb
python train.py --dataset cmnist --exp=cmnist_1_ours --lr=0.01 --percent=1pct  --curr_step=10000 --lambda_swap=1 --lambda_dis_align=10 --lambda_swap_align=10 --use_lr_decay --lr_decay_step=10000 --lr_gamma=0.5 --train_ours --tensorboard --wandb
python train.py --dataset cmnist --exp=cmnist_2_ours --lr=0.01 --percent=2pct  --curr_step=10000 --lambda_swap=1 --lambda_dis_align=10 --lambda_swap_align=10 --use_lr_decay --lr_decay_step=10000 --lr_gamma=0.5 --train_ours --tensorboard --wandb
python train.py --dataset cmnist --exp=cmnist_5_ours --lr=0.01 --percent=5pct  --curr_step=10000 --lambda_swap=1 --lambda_dis_align=10 --lambda_swap_align=10 --use_lr_decay --lr_decay_step=10000 --lr_gamma=0.5 --train_ours --tensorboard --wandb
bash scripts/run_cmnist_ours.sh

Corrupted CIFAR10

Vanilla
python train.py --dataset cifar10c --exp=cifar10c_0.5_vanilla --lr=0.001 --percent=0.5pct --train_vanilla --tensorboard --wandb
python train.py --dataset cifar10c --exp=cifar10c_1_vanilla --lr=0.001 --percent=1pct --train_vanilla --tensorboard --wandb
python train.py --dataset cifar10c --exp=cifar10c_2_vanilla --lr=0.001 --percent=2pct --train_vanilla --tensorboard --wandb
python train.py --dataset cifar10c --exp=cifar10c_5_vanilla --lr=0.001 --percent=5pct --train_vanilla --tensorboard --wandb
bash scripts/run_cifar10c_vanilla.sh
Ours
python train.py --dataset cifar10c --exp=cifar10c_0.5_ours --lr=0.0005 --percent=0.5pct --curr_step=10000 --lambda_swap=1 --lambda_dis_align=1 --lambda_swap_align=1 --use_lr_decay --lr_decay_step=10000 --lr_gamma=0.5 --train_ours --tensorboard --wandb
python train.py --dataset cifar10c --exp=cifar10c_1_ours --lr=0.001 --percent=1pct --curr_step=10000 --lambda_swap=1 --lambda_dis_align=5 --lambda_swap_align=5 --use_lr_decay --lr_decay_step=10000 --lr_gamma=0.5 --train_ours --tensorboard --wandb
python train.py --dataset cifar10c --exp=cifar10c_2_ours --lr=0.001 --percent=2pct --curr_step=10000 --lambda_swap=1 --lambda_dis_align=5 --lambda_swap_align=5 --use_lr_decay --lr_decay_step=10000 --lr_gamma=0.5 --train_ours --tensorboard --wandb
python train.py --dataset cifar10c --exp=cifar10c_5_ours --lr=0.001 --percent=5pct --curr_step=10000 --lambda_swap=1 --lambda_dis_align=1 --lambda_swap_align=1 --use_lr_decay --lr_decay_step=10000 --lr_gamma=0.5 --train_ours --tensorboard --wandb
bash scripts/run_cifar10c_ours.sh

BFFHQ

Vanilla
python train.py --dataset bffhq --exp=bffhq_0.5_vanilla --lr=0.0001 --percent=0.5pct --train_vanilla --tensorboard --wandb
bash scripts/run_bffhq_vanilla.sh
Ours
python train.py --dataset bffhq --exp=bffhq_0.5_ours --lr=0.0001 --percent=0.5pct --lambda_swap=0.1 --curr_step=10000 --use_lr_decay --lr_decay_step=10000 --lambda_dis_align 2. --lambda_swap_align 2. --dataset bffhq --train_ours --tensorboard --wandb
bash scripts/run_bffhq_ours.sh

Pretrained Models

In order to test our pretrained models, run the following command.

python test.py --pretrained_path=
   
     --dataset=
    
      --percent=
     

     
    
   

We provide the pretrained models in the following urls.
CMNIST 0.5pct
CMNIST 1pct
CMNIST 2pct
CMNIST 5pct

CIFAR10C 0.5pct
CIFAR10C 1pct
CIFAR10C 2pct
CIFAR10C 5pct

BFFHQ 0.5pct

Citations

Bibtex coming soon!

Contact

Jungsoo Lee

Eungyeup Kim

Juyoung Lee

Kakao Enterprise/Vision Team

Acknowledgments

This work was mainly done when both of the first authors were doing internship at Vision Team/AI Lab/Kakao Enterprise. Our pytorch implementation is based on LfF. Thanks for the implementation.

Owner
Kakao Enterprise Corp.
Kakao Enterprise Corp.
The official implementation of paper Siamese Transformer Pyramid Networks for Real-Time UAV Tracking, accepted by WACV22

SiamTPN Introduction This is the official implementation of the SiamTPN (WACV2022). The tracker intergrates pyramid feature network and transformer in

Robotics and Intelligent Systems Control @ NYUAD 28 Nov 25, 2022
Code for our method RePRI for Few-Shot Segmentation. Paper at http://arxiv.org/abs/2012.06166

Region Proportion Regularized Inference (RePRI) for Few-Shot Segmentation In this repo, we provide the code for our paper : "Few-Shot Segmentation Wit

Malik Boudiaf 138 Dec 12, 2022
EMNLP 2021 Adapting Language Models for Zero-shot Learning by Meta-tuning on Dataset and Prompt Collections

Adapting Language Models for Zero-shot Learning by Meta-tuning on Dataset and Prompt Collections Ruiqi Zhong, Kristy Lee*, Zheng Zhang*, Dan Klein EMN

Ruiqi Zhong 42 Nov 03, 2022
A demonstration of using a live Tensorflow session to create an interactive face-GAN explorer.

Streamlit Demo: The Controllable GAN Face Generator This project highlights Streamlit's new hash_func feature with an app that calls on TensorFlow to

Streamlit 257 Dec 31, 2022
Mixed Neural Likelihood Estimation for models of decision-making

Mixed neural likelihood estimation for models of decision-making Mixed neural likelihood estimation (MNLE) enables Bayesian parameter inference for mo

mackelab 9 Dec 22, 2022
Data cleaning, missing value handle, EDA use in this project

Lending Club Case Study Project Brief Solving this assignment will give you an idea about how real business problems are solved using EDA. In this cas

Dhruvil Sheth 1 Jan 05, 2022
Official PyTorch implementation of our AAAI22 paper: TransMEF: A Transformer-Based Multi-Exposure Image Fusion Framework via Self-Supervised Multi-Task Learning. Code will be available soon.

Official-PyTorch-Implementation-of-TransMEF Official PyTorch implementation of our AAAI22 paper: TransMEF: A Transformer-Based Multi-Exposure Image Fu

117 Dec 27, 2022
MoCoGAN: Decomposing Motion and Content for Video Generation

MoCoGAN: Decomposing Motion and Content for Video Generation This repository contains an implementation and further details of MoCoGAN: Decomposing Mo

Sergey Tulyakov 514 Dec 18, 2022
Awesome-google-colab - Google Colaboratory Notebooks and Repositories

Unofficial Google Colaboratory Notebook and Repository Gallery Please contact me to take over and revamp this repo (it gets around 30k views and 200k

Derek Snow 1.2k Jan 03, 2023
Official pytorch implementation for Learning to Listen: Modeling Non-Deterministic Dyadic Facial Motion (CVPR 2022)

Learning to Listen: Modeling Non-Deterministic Dyadic Facial Motion This repository contains a pytorch implementation of "Learning to Listen: Modeling

50 Dec 17, 2022
🥇 LG-AI-Challenge 2022 1위 솔루션 입니다.

LG-AI-Challenge-for-Plant-Classification Dacon에서 진행된 농업 환경 변화에 따른 작물 병해 진단 AI 경진대회 에 대한 코드입니다. (colab directory에 코드가 잘 정리 되어있습니다.) Requirements python

siwooyong 10 Jun 30, 2022
This project aims at providing a concise, easy-to-use, modifiable reference implementation for semantic segmentation models using PyTorch.

Semantic Segmentation on PyTorch (include FCN, PSPNet, Deeplabv3, Deeplabv3+, DANet, DenseASPP, BiSeNet, EncNet, DUNet, ICNet, ENet, OCNet, CCNet, PSANet, CGNet, ESPNet, LEDNet, DFANet)

2.4k Jan 08, 2023
ReConsider is a re-ranking model that re-ranks the top-K (passage, answer-span) predictions of an Open-Domain QA Model like DPR (Karpukhin et al., 2020).

ReConsider ReConsider is a re-ranking model that re-ranks the top-K (passage, answer-span) predictions of an Open-Domain QA Model like DPR (Karpukhin

Facebook Research 47 Jul 26, 2022
A Tensorflow implementation of the Text Conditioned Auxiliary Classifier Generative Adversarial Network for Generating Images from text descriptions

A Tensorflow implementation of the Text Conditioned Auxiliary Classifier Generative Adversarial Network for Generating Images from text descriptions

Ayushman Dash 93 Aug 04, 2022
A PyTorch implementation of "Pathfinder Discovery Networks for Neural Message Passing"

A PyTorch implementation of "Pathfinder Discovery Networks for Neural Message Passing" (WebConf 2021). Abstract In this work we propose Pathfind

Benedek Rozemberczki 49 Dec 01, 2022
Free-duolingo-plus - Duolingo account creator that uses your invite code to get you free duolingo plus

free-duolingo-plus duolingo account creator that uses your invite code to get yo

1 Jan 06, 2022
From the basics to slightly more interesting applications of Tensorflow

TensorFlow Tutorials You can find python source code under the python directory, and associated notebooks under notebooks. Source code Description 1 b

Parag K Mital 5.6k Jan 09, 2023
Dense Deep Unfolding Network with 3D-CNN Prior for Snapshot Compressive Imaging, ICCV2021 [PyTorch Code]

Dense Deep Unfolding Network with 3D-CNN Prior for Snapshot Compressive Imaging, ICCV2021 [PyTorch Code]

Jian Zhang 20 Oct 24, 2022
The official implementation code of "PlantStereo: A Stereo Matching Benchmark for Plant Surface Dense Reconstruction."

PlantStereo This is the official implementation code for the paper "PlantStereo: A Stereo Matching Benchmark for Plant Surface Dense Reconstruction".

Wang Qingyu 14 Nov 28, 2022
PyTorch implementation of the paper Ultra Fast Structure-aware Deep Lane Detection

PyTorch implementation of the paper Ultra Fast Structure-aware Deep Lane Detection

1.4k Jan 06, 2023