Official implementation of paper Gradient Matching for Domain Generalization

Related tags

Deep Learningfish
Overview

Gradient Matching for Domain Generalisation

This is the official PyTorch implementation of Gradient Matching for Domain Generalisation. In our paper, we propose an inter-domain gradient matching (IDGM) objective that targets domain generalization by maximizing the inner product between gradients from different domains. To avoid computing the expensive second-order derivative of the IDGM objective, we derive a simpler first-order algorithm named Fish that approximates its optimization.

This repository contains code to reproduce the main results of our paper.

Dependencies

(Recommended) You can setup up conda environment with all required dependencies using environment.yml:

conda env create -f environment.yml
conda activate fish

Otherwise you can also install the following packages manually:

python=3.7.10
numpy=1.20.2
pytorch=1.8.1
torchaudio=0.8.1
torchvision=0.9.1
torch-cluster=1.5.9
torch-geometric=1.7.0
torch-scatter=2.0.6
torch-sparse=0.6.9
wilds=1.1.0
scikit-learn=0.24.2
scipy=1.6.3
seaborn=0.11.1
tqdm=4.61.0

Running Experiments

We offer options to train using our proposed method Fish or by using Empirical Risk Minimisation baseline. This can be specified by the --algorithm flag (either fish or erm).

CdSprites-N

We propose this simple shape-color dataset based on the dSprites dataset, which contains a collection of white 2D sprites of different shapes, scales, rotations and positions. The dataset contains N domains, where N can be specified. The goal is to classify the shape of the sprites, and there is a shape-color deterministic matching that is specific per domain. This way we have shape as the invariant feature and color as the spurious feature. On the test set, however, this correlation between color and shape is removed. See the image below for an illustration.

cdsprites

The CdSprites-N dataset can be downloaded here. After downloading, please extract the zip file to your preferred data dir (e.g. <your_data_dir>/cdsprites). The following command runs an experiment using Fish with number of domains N=15:

python main.py --dataset cdsprites --algorithm fish --data-dir <your_data_dir> --num-domains 15

The number of domains you can choose from are: N = 5, 10, 15, 20, 25, 30, 35, 40, 45, 50.

WILDS

We include the following 6 datasets from the WILDS benchmark: amazon, camelyon, civil, fmow, iwildcam, poverty. The datasets can be downloaded automatically to a specified data folder. For instance, to train with Fish on Amazon dataset, simply run:

python main.py --dataset amazon --algorithm fish --data-dir <your_data_dir>

This should automatically download the Amazon dataset to <your_data_dir>/wilds. Experiments on other datasets can be ran by the following commands:

python main.py --dataset camelyon --algorithm fish --data-dir <your_data_dir>
python main.py --dataset civil --algorithm fish --data-dir <your_data_dir>
python main.py --dataset fmow --algorithm fish --data-dir <your_data_dir>
python main.py --dataset iwildcam --algorithm fish --data-dir <your_data_dir>
python main.py --dataset poverty --algorithm fish --data-dir <your_data_dir>

Alternatively, you can also download the datasets to <your_data_dir>/wilds manually by following the instructions here. See current results on WILDS here: image

DomainBed

For experiments on datasets including CMNIST, RMNIST, VLCS, PACS, OfficeHome, TerraInc and DomainNet, we implemented Fish on the DomainBed benchmark (see here) and you can compare our algorithm against up to 20 SOTA baselines. See current results on DomainBed here:

image

Citation

If you make use of this code in your research, we would appreciate if you considered citing the paper that is most relevant to your work:

@article{shi2021gradient,
	title="Gradient Matching for Domain Generalization.",
	author="Yuge {Shi} and Jeffrey {Seely} and Philip H. S. {Torr} and N. {Siddharth} and Awni {Hannun} and Nicolas {Usunier} and Gabriel {Synnaeve}",
	journal="arXiv preprint arXiv:2104.09937",
	year="2021"}

Contributions

We welcome contributions via pull requests. Please email [email protected] or [email protected] for any question/request.

ML-based medical imaging using Azure

Disclaimer This code is provided for research and development use only. This code is not intended for use in clinical decision-making or for any other

Microsoft Azure 68 Dec 23, 2022
Network Enhancement implementation in pytorch

network_enahncement_pytorch Network Enhancement implementation in pytorch Research paper Network Enhancement: a general method to denoise weighted bio

Yen 1 Nov 12, 2021
Official PyTorch implementation for paper Context Matters: Graph-based Self-supervised Representation Learning for Medical Images

Context Matters: Graph-based Self-supervised Representation Learning for Medical Images Official PyTorch implementation for paper Context Matters: Gra

49 Nov 23, 2022
Anchor-free Oriented Proposal Generator for Object Detection

Anchor-free Oriented Proposal Generator for Object Detection Gong Cheng, Jiabao Wang, Ke Li, Xingxing Xie, Chunbo Lang, Yanqing Yao, Junwei Han, Intro

jbwang1997 56 Nov 15, 2022
Si Adek Keras is software VR dangerous object detection.

Si Adek Python Keras Sistem Informasi Deteksi Benda Berbahaya Keras Python. Version 1.0 Developed by Ananda Rauf Maududi. Developed date: 24 November

Ananda Rauf 1 Dec 21, 2021
The PyTorch implementation of paper REST: Debiased Social Recommendation via Reconstructing Exposure Strategies

REST The PyTorch implementation of paper REST: Debiased Social Recommendation via Reconstructing Exposure Strategies. Usage Download dataset Download

DMIRLAB 2 Mar 13, 2022
EgGateWayGetShell py脚本

EgGateWayGetShell_py 免责声明 由于传播、利用此文所提供的信息而造成的任何直接或者间接的后果及损失,均由使用者本人负责,作者不为此承担任何责任。 使用 python3 eg.py urls.txt 目标 title:锐捷网络-EWEB网管系统 port:4430 漏洞成因 ?p

榆木 61 Nov 09, 2022
atmaCup #11 の Public 4th / Pricvate 5th Solution のリポジトリです。

#11 atmaCup 2021-07-09 ~ 2020-07-21 に行われた #11 [初心者歓迎! / 画像編] atmaCup のリポジトリです。結果は Public 4th / Private 5th でした。 フレームワークは PyTorch で、実装は pytorch-image-m

Tawara 12 Apr 07, 2022
Research code for CVPR 2021 paper "End-to-End Human Pose and Mesh Reconstruction with Transformers"

MeshTransformer ✨ This is our research code of End-to-End Human Pose and Mesh Reconstruction with Transformers. MEsh TRansfOrmer is a simple yet effec

Microsoft 473 Dec 31, 2022
PyTorch implementation of Constrained Policy Optimization

PyTorch implementation of Constrained Policy Optimization (CPO) This repository has a simple to understand and use implementation of CPO in PyTorch. A

Sapana Chaudhary 25 Dec 08, 2022
Use your Philips Hue lights as Racing Flags. Works with Assetto Corsa, Assetto Corsa Competizione and iRacing.

phue-racing-flags Use your Philips Hue lights as Racing Flags. Explore the docs » Report Bug · Request Feature Table of Contents About The Project Bui

50 Sep 03, 2022
A Lightweight Hyperparameter Optimization Tool 🚀

Lightweight Hyperparameter Optimization 🚀 The mle-hyperopt package provides a simple and intuitive API for hyperparameter optimization of your Machin

136 Jan 08, 2023
Image super-resolution (SR) is a fast-moving field with novel architectures attracting the spotlight

Revisiting RCAN: Improved Training for Image Super-Resolution Introduction Image super-resolution (SR) is a fast-moving field with novel architectures

Zudi Lin 76 Dec 01, 2022
Companion code for the paper "Meta-Learning the Search Distribution of Black-Box Random Search Based Adversarial Attacks" by Yatsura et al.

META-RS This is the companion code for the paper "Meta-Learning the Search Distribution of Black-Box Random Search Based Adversarial Attacks" by Yatsu

Bosch Research 7 Dec 09, 2022
Simple API for UCI Machine Learning Dataset Repository (search, download, analyze)

A simple API for working with University of California, Irvine (UCI) Machine Learning (ML) repository Table of Contents Introduction About Page of the

Tirthajyoti Sarkar 223 Dec 05, 2022
A small demonstration of using WebDataset with ImageNet and PyTorch Lightning

A small demonstration of using WebDataset with ImageNet and PyTorch Lightning

Tom 50 Dec 16, 2022
Instance-based label smoothing for improving deep neural networks generalization and calibration

Instance-based Label Smoothing for Neural Networks Pytorch Implementation of the algorithm. This repository includes a new proposed method for instanc

Mohamed Maher 1 Aug 13, 2022
Official implementation for “Unsupervised Low-Light Image Enhancement via Histogram Equalization Prior”

Unsupervised Low-Light Image Enhancement via Histogram Equalization Prior. The code will release soon. Implementation Python3 PyTorch=1.0 NVIDIA GPU+

FengZhang 34 Dec 04, 2022
Training BERT with Compute/Time (Academic) Budget

Training BERT with Compute/Time (Academic) Budget This repository contains scripts for pre-training and finetuning BERT-like models with limited time

Intel Labs 263 Jan 07, 2023
Multiview Dataset Toolkit

Multiview Dataset Toolkit Using multi-view cameras is a natural way to obtain a complete point cloud. However, there is to date only one multi-view 3D

11 Dec 22, 2022