Official PyTorch implementation of the Fishr regularization for out-of-distribution generalization

Related tags

Deep Learningfishr
Overview

Fishr: Invariant Gradient Variances for Out-of-distribution Generalization

Official PyTorch implementation of the Fishr regularization for out-of-distribution generalization | paper

Alexandre Ramé, Corentin Dancette, Matthieu Cord

Abstract

Learning robust models that generalize well under changes in the data distribution is critical for real-world applications. To this end, there has been a growing surge of interest to learn simultaneously from multiple training domains - while enforcing different types of invariance across those domains. Yet, all existing approaches fail to show systematic benefits under fair evaluation protocols.

In this paper, we propose a new learning scheme to enforce domain invariance in the space of the gradients of the loss function: specifically, we introduce a regularization term that matches the domain-level variances of gradients across training domains. Critically, our strategy, named Fishr, exhibits close relations with the Fisher Information and the Hessian of the loss. We show that forcing domain-level gradient covariances to be similar during the learning procedure eventually aligns the domain-level loss landscapes locally around the final weights.

Extensive experiments demonstrate the effectiveness of Fishr for out-of-distribution generalization. In particular, Fishr improves the state of the art on the DomainBed benchmark and performs significantly better than Empirical Risk Minimization.

Installation

Requirements overview

Our implementation relies on the BackPACK package in PyTorch to easily compute gradient variances.

  • python == 3.7.10
  • torch == 1.8.1
  • torchvision == 0.9.1
  • backpack-for-pytorch == 1.3.0
  • numpy == 1.20.2

Procedure

  1. Clone the repo:
$ git clone https://github.com/alexrame/fishr.git
  1. Install this repository and the dependencies using pip:
$ conda create --name fishr python=3.7.10
$ conda activate fishr
$ cd fishr
$ pip install -r requirements.txt

With this, you can edit the Fishr code on the fly.

Overview

This github enables the replication of our two main experiments: (1) on Colored MNIST in the setup defined by IRM and (2) on the DomainBed benchmark.

Colored MNIST in the IRM setup

We first validate that Fishr tackles distribution shifts on the synthetic Colored MNIST.

Main results (Table 2 in Section 6.A)

To reproduce the results from Table 2, call python3 coloredmnist/train_coloredmnist.py --algorithm $algorithm where algorithm is either:

Results will be printed at the end of the script, averaged over 10 runs. Note that all hyperparameters are taken from the seminal IRM implementation.

    Method | Train acc. | Test acc.  | Gray test acc.
   --------|------------|------------|----------------
    ERM    | 86.4 ± 0.2 | 14.0 ± 0.7 |   71.0 ± 0.7
    IRM    | 71.0 ± 0.5 | 65.6 ± 1.8 |   66.1 ± 0.2
    V-REx  | 71.7 ± 1.5 | 67.2 ± 1.5 |   68.6 ± 2.2
    Fishr  | 71.0 ± 0.9 | 69.5 ± 1.0 |   70.2 ± 1.1

Without label flipping (Table 5 in Appendix C.2.3)

The script coloredmnist.train_coloredmnist also accepts as input the argument --label_flipping_prob which defines the label flipping probability. By default, it's 0.25, so to reproduce the results from Table 5 you should set --label_flipping_prob 0.

Fishr variants (Table 6 in Appendix C.2.4)

This table considers two additional Fishr variants, reproduced with algorithm set to:

  • fishr_offdiagonal for Fishr but without centering the gradient variances
  • fishr_notcentered for Fishr but on the full covariance rather than only the diagonal

DomainBed

DomainBed is a PyTorch suite containing benchmark datasets and algorithms for domain generalization, as introduced in In Search of Lost Domain Generalization. Instructions below are copied and adapted from the official github.

Algorithms and hyperparameter grids

We added Fishr as a new algorithm here, and defined Fishr's hyperparameter grids here, as defined in Table 7 in Appendix D.

Datasets

We ran Fishr on following datasets:

Launch training

Download the datasets:

python3 -m domainbed.scripts.download\
       --data_dir=/my/data/dir

Train a model for debugging:

python3 -m domainbed.scripts.train\
       --data_dir=/my/data/dir/\
       --algorithm Fishr\
       --dataset ColoredMNIST\
       --test_env 2

Launch a sweep for hyperparameter search:

python -m domainbed.scripts.sweep launch\
       --data_dir=/my/data/dir/\
       --output_dir=/my/sweep/output/path\
       --command_launcher MyLauncher
       --datasets ColoredMNIST\
       --algorithms Fishr

Here, MyLauncher is your cluster's command launcher, as implemented in command_launchers.py.

Performances inspection (Tables 3 and 4 in Section 6.B.2, Tables in Appendix G)

To view the results of your sweep:

python -m domainbed.scripts.collect_results\
       --input_dir=/my/sweep/output/path

We inspect performances using following model selection criteria, that differ in what data is used to choose the best hyper-parameters for a given model:

  • OracleSelectionMethod (Oracle): A random subset from the data of the test domain.
  • IIDAccuracySelectionMethod (Training): A random subset from the data of the training domains.

Critically, Fishr performs consistently better than Empirical Risk Minimization.

Model selection Algorithm Colored MNIST Rotated MNIST VLCS PACS OfficeHome TerraIncognita DomainNet Avg
Oracle ERM 57.8 ± 0.2 97.8 ± 0.1 77.6 ± 0.3 86.7 ± 0.3 66.4 ± 0.5 53.0 ± 0.3 41.3 ± 0.1 68.7
Oracle Fishr 68.8 ± 1.4 97.8 ± 0.1 78.2 ± 0.2 86.9 ± 0.2 68.2 ± 0.2 53.6 ± 0.4 41.8 ± 0.2 70.8
Training ERM 51.5 ± 0.1 98.0 ± 0.0 77.5 ± 0.4 85.5 ± 0.2 66.5 ± 0.3 46.1 ± 1.8 40.9 ± 0.1 66.6
Training Fishr 52.0 ± 0.2 97.8 ± 0.0 77.8 ± 0.1 85.5 ± 0.4 67.8 ± 0.1 47.4 ± 1.6 41.7 ± 0.0 67.1

Conclusion

We addressed the task of out-of-distribution generalization for computer vision classification tasks. We derive a new and simple regularization - Fishr - that matches the gradient variances across domains as a proxy for matching domain-level Hessians. Our scalable strategy reaches state-of-the-art performances on the DomainBed benchmark and performs better than ERM. Our empirical experiments suggest that Fishr regularization would consistently improve a deep classifier in real-world applications when dealing with data from multiple domains. If you need help to use Fishr, please open an issue or contact [email protected].

Citation

If you find this code useful for your research, please consider citing our work (under review):

@article{rame2021ishr,
    title={Fishr: Invariant Gradient Variances for Out-of-distribution Generalization},
    author={Alexandre Rame and Corentin Dancette and Matthieu Cord},
    year={2021},
    journal={arXiv preprint arXiv:2109.02934}
}
My freqtrade strategies

My freqtrade-strategies Hi there! This is repo for my freqtrade-strategies. My name is Ilya Zelenchuk, I'm a lecturer at the SPbU university (https://

171 Dec 05, 2022
ShinRL: A Library for Evaluating RL Algorithms from Theoretical and Practical Perspectives

Status: Under development (expect bug fixes and huge updates) ShinRL: A Library for Evaluating RL Algorithms from Theoretical and Practical Perspectiv

37 Dec 28, 2022
Vertical Federated Principal Component Analysis and Its Kernel Extension on Feature-wise Distributed Data based on Pytorch Framework

VFedPCA+VFedAKPCA This is the official source code for the Paper: Vertical Federated Principal Component Analysis and Its Kernel Extension on Feature-

John 9 Sep 18, 2022
K-FACE Analysis Project on Pytorch

Installation Setup with Conda # create a new environment conda create --name insightKface python=3.7 # or over conda activate insightKface #install t

Jung Jun Uk 7 Nov 10, 2022
An open-source, low-cost, image-based weed detection device for fallow scenarios.

Welcome to the OpenWeedLocator (OWL) project, an opensource hardware and software green-on-brown weed detector that uses entirely off-the-shelf compon

Guy Coleman 145 Jan 05, 2023
Deep Federated Learning for Autonomous Driving

FADNet: Deep Federated Learning for Autonomous Driving Abstract Autonomous driving is an active research topic in both academia and industry. However,

AIOZ AI 12 Dec 01, 2022
Husein pet projects in here!

project-suka-suka Husein pet projects in here! List of projects mysejahtera-density. Generate resolution points using meshgrid and request each points

HUSEIN ZOLKEPLI 47 Dec 09, 2022
Bayesian Generative Adversarial Networks in Tensorflow

Bayesian Generative Adversarial Networks in Tensorflow This repository contains the Tensorflow implementation of the Bayesian GAN by Yunus Saatchi and

Andrew Gordon Wilson 1k Nov 29, 2022
Official Chainer implementation of GP-GAN: Towards Realistic High-Resolution Image Blending (ACMMM 2019, oral)

GP-GAN: Towards Realistic High-Resolution Image Blending (ACMMM 2019, oral) [Project] [Paper] [Demo] [Related Work: A2RL (for Auto Image Cropping)] [C

Wu Huikai 402 Dec 27, 2022
Adversarial Reweighting for Partial Domain Adaptation

Adversarial Reweighting for Partial Domain Adaptation Code for paper "Xiang Gu, Xi Yu, Yan Yang, Jian Sun, Zongben Xu, Adversarial Reweighting for Par

12 Dec 01, 2022
Justmagic - Use a function as a method with this mystic script, like in Nim

justmagic Use a function as a method with this mystic script, like in Nim. Just

witer33 8 Oct 08, 2022
PyTorch implementation of the ExORL: Exploratory Data for Offline Reinforcement Learning

ExORL: Exploratory Data for Offline Reinforcement Learning This is an original PyTorch implementation of the ExORL framework from Don't Change the Alg

Denis Yarats 52 Jan 01, 2023
PyTorch code accompanying the paper "Landmark-Guided Subgoal Generation in Hierarchical Reinforcement Learning" (NeurIPS 2021).

HIGL This is a PyTorch implementation for our paper: Landmark-Guided Subgoal Generation in Hierarchical Reinforcement Learning (NeurIPS 2021). Our cod

Junsu Kim 20 Dec 14, 2022
Code for the paper "Regularizing Variational Autoencoder with Diversity and Uncertainty Awareness"

DU-VAE This is the pytorch implementation of the paper "Regularizing Variational Autoencoder with Diversity and Uncertainty Awareness" Acknowledgement

Dazhong Shen 4 Oct 19, 2022
A PyTorch implementation of the Relational Graph Convolutional Network (RGCN).

Torch-RGCN Torch-RGCN is a PyTorch implementation of the RGCN, originally proposed by Schlichtkrull et al. in Modeling Relational Data with Graph Conv

Thiviyan Singam 66 Nov 30, 2022
Deformable DETR is an efficient and fast-converging end-to-end object detector.

Deformable DETR: Deformable Transformers for End-to-End Object Detection.

2k Jan 05, 2023
Easy Parallel Library (EPL) is a general and efficient deep learning framework for distributed model training.

English | 简体中文 Easy Parallel Library Overview Easy Parallel Library (EPL) is a general and efficient library for distributed model training. Usability

Alibaba 185 Dec 21, 2022
Read Like Humans: Autonomous, Bidirectional and Iterative Language Modeling for Scene Text Recognition

Read Like Humans: Autonomous, Bidirectional and Iterative Language Modeling for Scene Text Recognition The official code of ABINet (CVPR 2021, Oral).

334 Dec 31, 2022
This repo contains the pytorch implementation for Dynamic Concept Learner (accepted by ICLR 2021).

DCL-PyTorch Pytorch implementation for the Dynamic Concept Learner (DCL). More details can be found at the project page. Framework Grounding Physical

Zhenfang Chen 31 Jan 06, 2023
The codes and related files to reproduce the results for Image Similarity Challenge Track 2.

ISC-Track2-Submission The codes and related files to reproduce the results for Image Similarity Challenge Track 2. Required dependencies To begin with

Wenhao Wang 89 Jan 02, 2023