Code accompanying the paper "An Empirical Investigation of Domain Generalization with Empirical Risk Minimizers," published at NeurIPS 2021

Overview

Code for "An Empirical Investigation of Domain Generalization with Empirical Risk Minimizers" (NeurIPS 2021)

Motivation and Introduction

Domain generalization is a machine learning task in which a model trained on data from one or more source distributions is expected to perform well on test data drawn from a different distribution. For example, one might train a digit classifier on MNIST and then ask the model to predict digits that have been rotated by, say, 30 degrees.
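For concreteness, here is a minimal sketch (not part of this repository, and assuming torchvision is installed) of how such a rotated-digit shift could be constructed:

from torchvision import datasets, transforms
from torchvision.transforms import functional as TF

class RotateDigits:
    """Rotate every digit by a fixed angle to simulate a domain shift."""
    def __init__(self, angle):
        self.angle = angle

    def __call__(self, img):
        return TF.rotate(img, self.angle)

# Source domain: standard MNIST digits.
source = datasets.MNIST("data", train=True, download=True,
                        transform=transforms.ToTensor())
# Target domain: the same kind of digits, rotated by 30 degrees.
target = datasets.MNIST("data", train=False, download=True,
                        transform=transforms.Compose(
                            [RotateDigits(30), transforms.ToTensor()]))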

While many approaches have been proposed for this problem, we were intrigued by results on the DomainBed benchmark suggesting that simple empirical risk minimization (ERM), with a proper hyperparameter sweep, performs close to the state of the art on domain generalization problems.

What governs how well a deep learning model trained with ERM generalizes to a given data distribution? This is the question we seek to answer in our NeurIPS 2021 paper:

An Empirical Investigation of Domain Generalization with Empirical Risk Minimizers. Rama Vedantam, David Lopez-Paz*, David Schwab*.

NeurIPS 2021 (*=Equal Contribution)

This repository contains code used for producing the results in our paper.

Initial Setup

  1. Run source init.sh to install all the dependencies for the project. This also initializes DomainBed as a submodule.

  2. Set requisite paths in setup.sh, and run source setup.sh

Computing Generalization Measures

  • Get set up with the DomainBed codebase and launch a sweep to train an initial set of models (illustrated below for the RotatedMNIST dataset):
cd DomainBed/

python -m domainbed.scripts.sweep launch \
       --data_dir=${DOMAINBED_DATA} \
       --output_dir=${DOMAINBED_RUN_DIR}/sweep_fifty_fifty \
       --algorithms=ERM \
       --holdout_fraction=0.5 \
       --datasets=RotatedMNIST \
       --n_hparams=1 \
       --command_launcher submitit

After this step, we have a set of trained models that we can now evaluate and measure. Note that unlike the original DomainBed paper, we hold out a larger fraction (50%) of the data for evaluating the measures.

  • Once the sweep finishes, aggregate the different output files for use by the domainbed_measures codebase:
python domainbed_measures/write_job_status_file.py \
                --sweep_dir=${DOMAINBED_RUN_DIR}/sweep_fifty_fifty \
                --output_txt="domainbed_measures/scratch/sweep_release.txt"
  • Once this step is complete, we can compute various generalization measures and store them to disk for future analysis using:
SLURM_PARTITION="TO_BE_SET"
python domainbed_measures/compute_gen_correlations.py \
    --algorithm=ERM \
    --job_done_file="domainbed_measures/scratch/sweep_release.txt" \
    --run_dir=${MEASURE_RUN_DIR} \
    --all_measures_one_job \
    --slurm_partition=${SLURM_PARTITION}

Here we use Slurm on a compute cluster to scale the experiments to thousands of models. If you do not have access to such a cluster with multiple GPUs to parallelize the computation, pass --slurm_partition="" above and the code will run on a single GPU (although the results might take a long time to compute!).

  • Finally, once the above jobs are done, use the following snippet to aggregate the values of the different generalization measures:
python domainbed_measures/extract_generalization_features.py \
    --run_dir=${MEASURE_RUN_DIR} \
    --sweep_name="_out_ERM_RotatedMNIST"

This step yields .csv files in which each row corresponds to a given trained model, with the following overall format:

dataset | test_envs | measure 1 | measure 2 | measure 3 | target_err

where:

  • test_envs specifies which environments the model is tested on (equivalently, which it is not trained on, since the remaining environments are used for training)
  • target_err specifies the target error value for regression
  • measure 1 specifies which measure is being computed, e.g. sharpness or Fisher eigenvalue-based measures

In the file named, for example, sweeps__out_ERM_RotatedMNIST_canon_False_ood.csv, the within-domain validation error wd_out_domain_err is also used as one of the measures, target_err is the out-of-domain generalization error, and all measures are computed on a held-out set of image inputs from the target domain (for more details, see the paper).

Alternatively, in the file named sweeps__out_ERM_RotatedMNIST_canon_False_wd.csv, target_err is the in-domain validation error, and all measures are computed on in-distribution held-out images.
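As a quick sanity check, one can load such a file with pandas and look at rank correlations between a single measure and target_err. The sketch below assumes the column names follow the table in the next section (e.g. sharp_mag); the exact layout of the generated files may differ slightly:

import pandas as pd

df = pd.read_csv("sweeps__out_ERM_RotatedMNIST_canon_False_ood.csv")
print(df.columns.tolist())  # dataset, test_envs, measures..., target_err

# Spearman rank correlation between one measure and out-of-domain error.
print(df["sharp_mag"].corr(df["target_err"], method="spearman"))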

  • Using these files, one can carry out a number of interesting regression analyses, as reported in the paper, for measuring generalization.

For example, to generate the kind of results in Table 1 of the paper in the joint setting, run:

python domainbed_measures/analyze_results.py \
    --input_csv="${MEASURE_RUN_DIR}/sweeps__out_ERM_RotatedMNIST_canon_False_ood.csv"\
    --stratified_or_joint="joint"\
    --num_features=2 \
    --fix_one_feature_to_wd

Alternatively, to generate results in the stratified setting, run:

python domainbed_measures/analyze_results.py \
    --input_csv="${MEASURE_RUN_DIR}/sweeps__out_ERM_RotatedMNIST_canon_False_ood.csv"\
    --stratified_or_joint="stratified"\
    --num_features=2 \
    --fix_one_feature_to_wd

Finally, to generate results using a single feature (the Alone setting in Table 1), run:

python domainbed_measures/analyze_results.py \
    --input_csv="${MEASURE_RUN_DIR}/sweeps__out_ERM_RotatedMNIST_canon_False_ood.csv"\
    --num_features=1
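To convey what these analyses do, the sketch below fits a simple two-feature linear regression with the in-domain validation error (wd_out_domain_err) fixed as one feature, in the spirit of the joint setting. This is illustrative only: analyze_results.py performs the actual, more careful feature selection, and the second feature name here is just one example from the table below.

import pandas as pd
from sklearn.linear_model import LinearRegression

df = pd.read_csv("sweeps__out_ERM_RotatedMNIST_canon_False_ood.csv")
X = df[["wd_out_domain_err", "fisher_eigvec_align_ex_75"]].values
y = df["target_err"].values

reg = LinearRegression().fit(X, y)
print("R^2:", reg.score(X, y))  # variance in OOD error explained by the two measures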

Translation of measures from the code to the paper

The following table lists all the measures in the paper (Appendix Table 2) and how they are referred to in the codebase:

Measure Name | Code Reference
H-divergence | c2st
H-divergence + Source Error | c2st_perr
H-divergence MS | c2st_per_env
H-divergence MS + Source Error | c2st_per_env_perr
H-divergence (train) | c2st_train
H-divergence (train) + Source Error | c2st_train_perr
H-divergence (train) MS | c2st_train_per_env
Entropy-Source or Entropy | entropy
Entropy-Target | entropy_held_out
Fisher-Eigval-Diff | fisher_eigval_sum_diff_ex_75
Fisher-Eigval | fisher_eigval_sum_ex_75
Fisher-Align or Fisher (main paper) | fisher_eigvec_align_ex_75
HΔH-divergence SS | hdh
HΔH-divergence SS + Source Error | hdh_perr
HΔH-divergence MS | hdh_per_env
HΔH-divergence MS + Source Error | hdh_per_env_perr
HΔH-divergence (train) SS | hdh_train
HΔH-divergence (train) SS + Source Error | hdh_train_perr
Jacobian | jacobian_norm
Jacobian Ratio | jacobian_norm_relative
Jacobian Diff | jacobian_norm_relative_diff
Jacobian Log Ratio | jacobian_norm_relative_log_diff
Mixup | mixup
Mixup Ratio | mixup_relative
Mixup Diff | mixup_relative_diff
Mixup Log Ratio | mixup_relative_log_diff
MMD-Gaussian | mmd_gaussian
MMD-Mean-Cov | mmd_mean_cov
L2-Path-Norm. | path_norm
Sharpness | sharp_mag
H+-divergence SS | v_plus_c2st
H+-divergence SS + Source Error | v_plus_c2st_perr
H+-divergence MS | v_plus_c2st_per_env
H+-divergence MS + Source Error | v_plus_c2st_per_env_perr
H+ΔH+-divergence SS | v_plus_hdh
H+ΔH+-divergence SS + Source Error | v_plus_hdh_perr
H+ΔH+-divergence MS | v_plus_hdh_per_env
H+ΔH+-divergence MS + Source Error | v_plus_hdh_per_env_perr
Source Error | wd_out_domain_err

Acknowledgments

We thank the developers of Decodable Information Bottleneck, DomainBed, and Jonathan Frankle for code we found useful for this project.

License

This source code is released under the Creative Commons Attribution-NonCommercial 4.0 International license, included here.
