Code accompanying the paper "An Empirical Investigation of Domain Generalization with Empirical Risk Minimizers", published at NeurIPS 2021

Overview

Code for "An Empirical Investigation of Domain Generalization with Empirical Risk Minimizers" (NeurIPS 2021)

Motivation and Introduction

Domain generalization is a machine learning task in which a model trained on data from one or more input distributions is expected to perform well on a test task whose input distribution differs from the training one. For example, one might train a digit classifier on MNIST and ask it to predict digits rotated by, say, 30 degrees.

While many approaches have been proposed for this problem, we were intrigued by results on the DomainBed benchmark suggesting that simple empirical risk minimization (ERM), combined with a proper hyperparameter sweep, performs close to the state of the art on domain generalization problems.

What governs how well a deep learning model trained with ERM generalizes to a given data distribution? This is the question we seek to answer in our NeurIPS 2021 paper:

An Empirical Investigation of Domain Generalization with Empirical Risk Minimizers. Rama Vedantam, David Lopez-Paz*, David Schwab*.

NeurIPS 2021 (*=Equal Contribution)

This repository contains code used for producing the results in our paper.

Initial Setup

  1. Run source init.sh to install all the project's dependencies. This also initializes DomainBed as a submodule of the project.

  2. Set the required paths in setup.sh, then run source setup.sh.
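The commands in the following steps read three environment variables, DOMAINBED_DATA, DOMAINBED_RUN_DIR and MEASURE_RUN_DIR, which setup.sh is presumably meant to define. As a minimal sketch that is not part of the repository, a check like the following can confirm they are set before anything is launched; check_paths and REQUIRED_VARS are hypothetical names.

```python
import os

# Environment variables referenced by the commands in this README.
REQUIRED_VARS = ("DOMAINBED_DATA", "DOMAINBED_RUN_DIR", "MEASURE_RUN_DIR")


def check_paths(env=None):
    """Return the names of required variables that are unset or empty.

    Hypothetical helper: it only inspects the given mapping (the process
    environment by default) and does not parse setup.sh itself.
    """
    env = os.environ if env is None else env
    return [name for name in REQUIRED_VARS if not env.get(name)]
```

After source setup.sh, check_paths() should return an empty list; any names it returns still need to be set.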

Computing Generalization Measures

  • Set up the DomainBed codebase and launch a sweep to obtain an initial set of trained models (illustrated below for the RotatedMNIST dataset):
cd DomainBed/

python -m domainbed.scripts.sweep launch \
       --data_dir=${DOMAINBED_DATA} \
       --output_dir=${DOMAINBED_RUN_DIR}/sweep_fifty_fifty \
       --algorithms=ERM \
       --holdout_fraction=0.5 \
       --datasets=RotatedMNIST \
       --n_hparams=1 \
       --command_launcher submitit

After this step, we have a set of trained models whose generalization we can now evaluate and measure. Note that, unlike the original DomainBed paper, we hold out a larger fraction (50%) of the data for evaluating the measures.
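With --holdout_fraction=0.5, the data of each environment is split into two halves: one used for training and one held out for evaluating the measures. The sketch below illustrates such a seeded split under that assumption; it is not DomainBed's own splitting code, and split_indices is a hypothetical name.

```python
import random


def split_indices(n_examples, holdout_fraction=0.5, seed=0):
    """Shuffle example indices with a fixed seed and split them in two.

    Returns (train_idx, holdout_idx); the held-out part receives
    int(n_examples * holdout_fraction) indices.
    """
    indices = list(range(n_examples))
    random.Random(seed).shuffle(indices)
    n_holdout = int(n_examples * holdout_fraction)
    return indices[n_holdout:], indices[:n_holdout]


train_idx, holdout_idx = split_indices(10, holdout_fraction=0.5, seed=0)
```

Fixing the seed makes the split reproducible, so the same held-out examples can be reused when the measures are computed later.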

  • Once the sweep finishes, aggregate the different files for use by the domainbed_measures codebase:
python domainbed_measures/write_job_status_file.py \
                --sweep_dir=${DOMAINBED_RUN_DIR}/sweep_fifty_fifty \
                --output_txt="domainbed_measures/scratch/sweep_release.txt"
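DomainBed's training script writes a file named done into each run's output directory when training finishes. As a rough sketch of the kind of aggregation this step performs, the following lists finished run directories under a sweep directory. The actual contents of sweep_release.txt are defined by write_job_status_file.py and may differ; list_finished_runs, write_status_file and the one-path-per-line format are hypothetical.

```python
import os
import tempfile


def list_finished_runs(sweep_dir):
    """Return sorted paths of run directories containing a 'done' marker file."""
    finished = []
    with os.scandir(sweep_dir) as entries:
        for entry in entries:
            if entry.is_dir() and os.path.isfile(os.path.join(entry.path, "done")):
                finished.append(entry.path)
    return sorted(finished)


def write_status_file(sweep_dir, output_txt):
    """Write one finished run directory per line and return the count.

    The one-path-per-line format is an assumption for illustration only.
    """
    runs = list_finished_runs(sweep_dir)
    with open(output_txt, "w") as f:
        f.writelines(run + "\n" for run in runs)
    return len(runs)


# Usage on a throwaway directory tree in which only run_a has finished.
with tempfile.TemporaryDirectory() as demo_dir:
    for run in ("run_a", "run_b"):
        os.makedirs(os.path.join(demo_dir, run))
    open(os.path.join(demo_dir, "run_a", "done"), "w").close()
    demo_runs = [os.path.basename(p) for p in list_finished_runs(demo_dir)]
    n_written = write_status_file(demo_dir, os.path.join(demo_dir, "status.txt"))
```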
  • Once this step is complete, we can compute various generalization measures and store them to disk for future analysis using:
SLURM_PARTITION="TO_BE_SET"
python domainbed_measures/compute_gen_correlations.py \
    --algorithm=ERM \
    --job_done_file="domainbed_measures/scratch/sweep_release.txt" \
    --run_dir=${MEASURE_RUN_DIR} \
    --all_measures_one_job \
    --slurm_partition=${SLURM_PARTITION}

Here we use Slurm on a compute cluster to scale the experiments to thousands of models. If you do not have access to a multi-GPU cluster to parallelize the computation, set --slurm_partition="" above and the code will run on a single GPU (although the computation may take a long time!).

  • Finally, once the above jobs are done, run the following to aggregate the values of the different generalization measures:
python domainbed_measures/extract_generalization_features.py \
    --run_dir=${MEASURE_RUN_DIR} \
    --sweep_name="_out_ERM_RotatedMNIST"

This step yields .csv files in which each row corresponds to one trained model. Each row has the following format:

dataset | test_envs | measure 1 | measure 2 | measure 3 | target_err

where:

  • test_envs specifies the environments the model is tested on (equivalently, which environments it is trained on, since all remaining environments are used for training)
  • target_err specifies the target error value for regression
  • measure 1, measure 2, ... hold the values of the individual generalization measures, e.g. sharpness or Fisher-eigenvalue-based measures

For files ending in _ood.csv, such as sweeps__out_ERM_RotatedMNIST_canon_False_ood.csv, the within-domain validation error wd_out_domain_err is also included as one of the measures, target_err is the out-of-domain generalization error, and all measures are computed on a held-out set of image inputs from the target domain (see the paper for details).

For files ending in _wd.csv, such as sweeps__out_ERM_RotatedMNIST_canon_False_wd.csv, target_err is the in-domain validation accuracy, and all measures are computed on held-out in-distribution images.
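As a sketch of how these files might be loaded for analysis, the snippet below splits each row into metadata (dataset, test_envs), measure values and target_err using only the standard library. It assumes a header row with exactly those column names and treats every other column as a measure; if the real files contain extra columns (for example an index column written by pandas), they would need to be excluded too. load_measures is a hypothetical name, and the sample rows and values are made up for illustration.

```python
import csv
import io

META_COLUMNS = ("dataset", "test_envs")
TARGET_COLUMN = "target_err"


def load_measures(csv_file):
    """Split each CSV row into (metadata, measures, target).

    Every column other than the metadata columns and target_err is
    treated as a generalization measure and parsed as a float.
    """
    rows = []
    for record in csv.DictReader(csv_file):
        meta = {k: record[k] for k in META_COLUMNS}
        target = float(record[TARGET_COLUMN])
        measures = {
            k: float(v)
            for k, v in record.items()
            if k not in META_COLUMNS and k != TARGET_COLUMN
        }
        rows.append((meta, measures, target))
    return rows


# Made-up example in the layout described above (values are not real results).
SAMPLE = """dataset,test_envs,sharp_mag,wd_out_domain_err,target_err
RotatedMNIST,[0],0.12,0.02,0.10
RotatedMNIST,[5],0.30,0.03,0.25
"""
rows = load_measures(io.StringIO(SAMPLE))
```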

  • These files support a number of regression analyses of generalization, as reported in the paper.

For example, to generate results like those in Table 1 of the paper in the joint setting, run:

python domainbed_measures/analyze_results.py \
    --input_csv="${MEASURE_RUN_DIR}/sweeps__out_ERM_RotatedMNIST_canon_False_ood.csv" \
    --stratified_or_joint="joint" \
    --num_features=2 \
    --fix_one_feature_to_wd

Alternatively, to generate results in the stratified setting, run:

python domainbed_measures/analyze_results.py \
    --input_csv="${MEASURE_RUN_DIR}/sweeps__out_ERM_RotatedMNIST_canon_False_ood.csv" \
    --stratified_or_joint="stratified" \
    --num_features=2 \
    --fix_one_feature_to_wd
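To make the two settings concrete, here is a minimal sketch assuming that "joint" means one regression of target_err on the measures pooled over all test environments, while "stratified" means a separate regression within each test_envs group. It fits ordinary least squares on two features, one of which is wd_out_domain_err (loosely mirroring --num_features=2 --fix_one_feature_to_wd), and reports R². analyze_results.py may score or fit models differently; fit_r2 and joint_and_stratified are hypothetical names and the data are synthetic.

```python
from collections import defaultdict


def fit_r2(x1, x2, y):
    """R^2 of the ordinary least-squares fit y ~ b0 + b1*x1 + b2*x2.

    Solves the centred 2x2 normal equations with Cramer's rule.
    """
    n = len(y)
    m1, m2, my = sum(x1) / n, sum(x2) / n, sum(y) / n
    a = [v - m1 for v in x1]
    b = [v - m2 for v in x2]
    c = [v - my for v in y]
    s11 = sum(u * u for u in a)
    s22 = sum(v * v for v in b)
    s12 = sum(u * v for u, v in zip(a, b))
    s1y = sum(u * w for u, w in zip(a, c))
    s2y = sum(v * w for v, w in zip(b, c))
    det = s11 * s22 - s12 * s12
    if det == 0:
        raise ValueError("features are collinear")
    b1 = (s1y * s22 - s2y * s12) / det
    b2 = (s11 * s2y - s12 * s1y) / det
    sse = sum((w - b1 * u - b2 * v) ** 2 for u, v, w in zip(a, b, c))
    sst = sum(w * w for w in c)
    return 1.0 - sse / sst


def joint_and_stratified(rows):
    """rows: (test_envs, measure, wd_out_domain_err, target_err) tuples.

    Joint: one fit pooled over all rows.
    Stratified: one fit per test_envs group.
    """
    def fit(group):
        return fit_r2([r[1] for r in group], [r[2] for r in group], [r[3] for r in group])

    groups = defaultdict(list)
    for r in rows:
        groups[r[0]].append(r)
    return fit(rows), {env: fit(g) for env, g in groups.items()}


# Synthetic data: the same linear trend in each test environment,
# but with a different offset per environment.
MEASURE = [0.1, 0.2, 0.3, 0.4]
WD_ERR = [0.02, 0.01, 0.04, 0.03]
rows = [("[0]", m, w, 0.5 * m + w + 0.05) for m, w in zip(MEASURE, WD_ERR)]
rows += [("[1]", m, w, 0.5 * m + w + 0.30) for m, w in zip(MEASURE, WD_ERR)]
joint, stratified = joint_and_stratified(rows)
```

On this synthetic data both per-environment fits are exact (R² of about 1), while the pooled fit reaches only about 0.2, because a single intercept cannot absorb the per-environment offset.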

Finally, to generate results using a single feature (Alone setting in Table 1), run:

python domainbed_measures/analyze_results.py \
    --input_csv="${MEASURE_RUN_DIR}/sweeps__out_ERM_RotatedMNIST_canon_False_ood.csv" \
    --num_features=1
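With a single feature and an intercept, the least-squares R² equals the squared Pearson correlation between the measure and target_err, which gives a simple way to rank measures on their own. The sketch below does this under that assumption; the script itself may use a different score, r2_single and rank_measures are hypothetical names, and the values are made up (sharp_mag and path_norm are code names from the table in the next section).

```python
def r2_single(x, y):
    """R^2 of y ~ a + b*x, i.e. the squared Pearson correlation."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    sxy = sum((u - mx) * (v - my) for u, v in zip(x, y))
    sxx = sum((u - mx) ** 2 for u in x)
    syy = sum((v - my) ** 2 for v in y)
    return sxy * sxy / (sxx * syy)


def rank_measures(measures, target):
    """Sort measure names by single-feature R^2, best first."""
    scores = {name: r2_single(values, target) for name, values in measures.items()}
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)


# Made-up values for illustration only.
target_err = [0.1, 0.2, 0.3, 0.4]
ranking = rank_measures(
    {"sharp_mag": [1.0, 2.0, 3.0, 4.0], "path_norm": [4.0, 1.0, 3.0, 2.0]},
    target_err,
)
```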

Translation of measures from the code to the paper

The following table lists all the measures in the paper (Appendix, Table 2) and how they are referred to in the codebase:

Measure Name | Code Reference
H-divergence | c2st
H-divergence + Source Error | c2st_perr
H-divergence MS | c2st_per_env
H-divergence MS + Source Error | c2st_per_env_perr
H-divergence (train) | c2st_train
H-divergence (train) + Source Error | c2st_train_perr
H-divergence (train) MS | c2st_train_per_env
Entropy-Source or Entropy | entropy
Entropy-Target | entropy_held_out
Fisher-Eigval-Diff | fisher_eigval_sum_diff_ex_75
Fisher-Eigval | fisher_eigval_sum_ex_75
Fisher-Align or Fisher (main paper) | fisher_eigvec_align_ex_75
HΔH-divergence SS | hdh
HΔH-divergence SS + Source Error | hdh_perr
HΔH-divergence MS | hdh_per_env
HΔH-divergence MS + Source Error | hdh_per_env_perr
HΔH-divergence (train) SS | hdh_train
HΔH-divergence (train) SS + Source Error | hdh_train_perr
Jacobian | jacobian_norm
Jacobian Ratio | jacobian_norm_relative
Jacobian Diff | jacobian_norm_relative_diff
Jacobian Log Ratio | jacobian_norm_relative_log_diff
Mixup | mixup
Mixup Ratio | mixup_relative
Mixup Diff | mixup_relative_diff
Mixup Log Ratio | mixup_relative_log_diff
MMD-Gaussian | mmd_gaussian
MMD-Mean-Cov | mmd_mean_cov
L2-Path-Norm | path_norm
Sharpness | sharp_mag
H+-divergence SS | v_plus_c2st
H+-divergence SS + Source Error | v_plus_c2st_perr
H+-divergence MS | v_plus_c2st_per_env
H+-divergence MS + Source Error | v_plus_c2st_per_env_perr
H+ΔH+-divergence SS | v_plus_hdh
H+ΔH+-divergence SS + Source Error | v_plus_hdh_perr
H+ΔH+-divergence MS | v_plus_hdh_per_env
H+ΔH+-divergence MS + Source Error | v_plus_hdh_per_env_perr
Source Error | wd_out_domain_err
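The divergence-based code names follow a few regular conventions that can be read off the table: a _perr suffix adds Source Error, _per_env marks the MS variant, _train marks the (train) variant, and a v_plus_ prefix marks the H+ family. The sketch below decodes these modifiers from a code name. The rules are inferred from the table rather than taken from the codebase, so names outside the table may not follow them; decode_measure is a hypothetical name.

```python
def decode_measure(code_name):
    """Split a divergence-measure code name into its base and modifiers.

    Rules are inferred from the measure table; unrecognised names raise.
    """
    name = code_name
    plus = name.startswith("v_plus_")
    if plus:
        name = name[len("v_plus_"):]
    source_error = name.endswith("_perr")
    if source_error:
        name = name[: -len("_perr")]
    per_env = name.endswith("_per_env")
    if per_env:
        name = name[: -len("_per_env")]
    train = name.endswith("_train")
    if train:
        name = name[: -len("_train")]
    bases = {
        ("c2st", False): "H-divergence",
        ("hdh", False): "HΔH-divergence",
        ("c2st", True): "H+-divergence",
        ("hdh", True): "H+ΔH+-divergence",
    }
    if (name, plus) not in bases:
        raise ValueError(f"not a divergence measure: {code_name}")
    return {
        "base": bases[(name, plus)],
        "MS": per_env,
        "train": train,
        "source_error": source_error,
    }
```

For example, decode_measure("c2st_per_env_perr") identifies the H-divergence MS + Source Error row of the table.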

Acknowledgments

We thank the developers of Decodable Information Bottleneck and DomainBed, and Jonathan Frankle, for code we found useful for this project.

License

This source code is released under the Creative Commons Attribution-NonCommercial 4.0 International license, included here.

Owner
Meta Research