Benchmarks for semi-supervised domain generalization.

Overview

Semi-Supervised Domain Generalization

This code is the official implementation of the following paper: Semi-Supervised Domain Generalization with Stochastic StyleMatch. The paper addresses a practical and yet under-studied setting for domain generalization: one needs to use limited labeled data along with abundant unlabeled data gathered from multiple distinct domains to learn a generalizable model. This setting greatly challenges existing domain generalization methods, which are not designed to deal with unlabeled data and are thus less scalable in practice. Our approach, StyleMatch, extends the pseudo-labeling-based FixMatch—a state-of-the-art semi-supervised learning framework—in two crucial ways: 1) a stochastic classifier is designed to reduce overfitting and 2) the two-view consistency learning paradigm in FixMatch is upgraded to a multi-view version with style augmentation as the third complementary view. Two benchmarks are constructed for evaluation. Please see the paper at https://arxiv.org/abs/2106.00592 for more details.
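The multi-view consistency idea described above can be illustrated with a minimal NumPy sketch. It is not the repository's implementation. It assumes the usual FixMatch recipe (pseudo-labels come from a weakly augmented view and are kept only above a confidence threshold), and it assumes the style-augmented view is supervised in the same way as the strongly augmented one. The function names and the 0.95 threshold are illustrative choices.

```python
import numpy as np


def log_softmax(logits):
    # Numerically stable log-softmax over the class axis.
    z = logits - logits.max(axis=1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))


def multiview_consistency_loss(logits_weak, logits_strong, logits_style,
                               threshold=0.95):
    """Hypothetical helper: masked cross-entropy on two extra views.

    Pseudo-labels are taken from the weak view; samples whose maximum
    probability is below `threshold` contribute zero loss.
    """
    probs = np.exp(log_softmax(logits_weak))
    pseudo = probs.argmax(axis=1)
    mask = (probs.max(axis=1) >= threshold).astype(float)

    def masked_ce(logits):
        log_p = log_softmax(logits)
        ce = -log_p[np.arange(len(pseudo)), pseudo]
        return float((ce * mask).mean())

    # Assumption: the strong view and the style view are weighted equally.
    return masked_ce(logits_strong) + masked_ce(logits_style)
```

In this sketch, a batch whose weak-view predictions are all below the threshold yields a loss of zero, so unconfident unlabeled samples are ignored.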

How to set up the environment

This code is built on top of Dassl.pytorch. Please follow the instructions provided in https://github.com/KaiyangZhou/Dassl.pytorch to install the dassl environment and to prepare the datasets, PACS and OfficeHome. The five random labeled-unlabeled splits can be downloaded at the following links: pacs, officehome. The splits need to be extracted into the corresponding dataset folders. Assuming you put the datasets under the directory $DATA, the structure should look like this:

$DATA/
    pacs/
        images/
        splits/
        splits_ssdg/
    office_home_dg/
        art/
        clipart/
        product/
        real_world/
        splits_ssdg/
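
Before launching a run, it can help to verify that the folders listed above are in place. The following is a small sketch, not part of the repository; the helper name `missing_dirs` is hypothetical, and the expected folder names are copied from the tree above.

```python
from pathlib import Path

# Folder names taken from the directory tree above.
EXPECTED = {
    "pacs": ["images", "splits", "splits_ssdg"],
    "office_home_dg": ["art", "clipart", "product", "real_world", "splits_ssdg"],
}


def missing_dirs(data_root):
    """Return the expected sub-folders that do not exist under data_root."""
    root = Path(data_root)
    return [
        (Path(dataset) / sub).as_posix()
        for dataset, subs in EXPECTED.items()
        for sub in subs
        if not (root / dataset / sub).is_dir()
    ]
```

Calling `missing_dirs` with the value of $DATA returns an empty list when the layout matches the tree above.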

The style augmentation is based on AdaIN, and its implementation builds on https://github.com/naoto0804/pytorch-AdaIN. Please download the weights of the decoder and the VGG from that repository and put them under a new folder, ssdg-benchmark/weights.
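
For readers unfamiliar with AdaIN, the core operation re-normalizes each channel of a content feature map so that its mean and standard deviation match those of a style feature map. The sketch below shows only that statistic-matching step in NumPy, applied to plain arrays. In the actual pipeline the operation is applied to VGG features and the result is passed through the decoder, which is not reproduced here. The function name and the `eps` value are illustrative.

```python
import numpy as np


def adain(content, style, eps=1e-5):
    """Adaptive instance normalization on (C, H, W) arrays.

    Each content channel is normalized to zero mean and unit variance,
    then rescaled and shifted with the style channel's statistics.
    """
    c_mean = content.mean(axis=(1, 2), keepdims=True)
    c_std = content.std(axis=(1, 2), keepdims=True) + eps  # avoid divide-by-zero
    s_mean = style.mean(axis=(1, 2), keepdims=True)
    s_std = style.std(axis=(1, 2), keepdims=True)
    return s_std * (content - c_mean) / c_std + s_mean
```

After this step, the per-channel mean and standard deviation of the output are (up to `eps`) those of the style input, while the spatial structure comes from the content input.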

How to run StyleMatch

The script is provided in ssdg-benchmark/scripts/StyleMatch/run_ssdg.sh. You need to update the DATA variable so that it points to the directory where you put the datasets. The script takes three input arguments: DATASET, NLAB (total number of labels), and CFG. See the tables below for how to set these values.

Dataset            NLAB
ssdg_pacs          210 or 105
ssdg_officehome    1950 or 975

CFG    Description
v1     FixMatch + stochastic classifier + T_style
v2     FixMatch + stochastic classifier + T_style-only (i.e. no T_strong)
v3     FixMatch + stochastic classifier
v4     FixMatch

v1 refers to StyleMatch, which is our final model. See the config files in configs/trainers/StyleMatch for the detailed settings.
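
The NLAB values in the table can be reproduced with simple arithmetic. The sketch below rests on assumptions that are not stated in this README: PACS has 7 classes and OfficeHome has 65, each dataset has four domains with one held out as the target (leaving three source domains), and labels are counted per class per source domain. These assumptions were chosen because they reproduce the numbers in the table.

```python
# Assumed class counts (not stated in this README).
NUM_CLASSES = {"ssdg_pacs": 7, "ssdg_officehome": 65}
# Assumed: four domains per dataset, one held out as the target.
NUM_SOURCE_DOMAINS = 3


def total_labels(dataset, labels_per_class):
    """Total labeled images = classes x source domains x labels per class."""
    return NUM_CLASSES[dataset] * NUM_SOURCE_DOMAINS * labels_per_class


for name in NUM_CLASSES:
    print(name, total_labels(name, 10), total_labels(name, 5))
```

Under these assumptions, 10 labels per class gives 210 for PACS and 1950 for OfficeHome, and 5 labels per class gives 105 and 975.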

Here is an example. To run StyleMatch on PACS under the 10-labels-per-class setting (i.e. 210 labels in total), run the following commands in your terminal:

conda activate dassl
cd ssdg-benchmark/scripts/StyleMatch
bash run_ssdg.sh ssdg_pacs 210 v1

In this case, the code will run StyleMatch in four different setups (one per target domain), each five times (five random seeds), for 20 runs in total. If you have multiple GPUs, you can modify the script to launch single experiments separately instead of running them all in sequence.

At the end of training, you will have

output/
    ssdg_pacs/
        nlab_210/
            StyleMatch/
                resnet18/
                    v1/ # contains results on four target domains
                        art_painting/ # contains five folders: seed1-5
                        cartoon/
                        photo/
                        sketch/
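
To check that all runs finished, one can enumerate the folders expected under the tree above. This is a sketch rather than a repository utility. The function name is hypothetical, and the seed folder names `seed1` to `seed5` are inferred from the "seed1-5" comment in the tree.

```python
from pathlib import Path

# Target domains as they appear in the output tree above.
TARGET_DOMAINS = ["art_painting", "cartoon", "photo", "sketch"]
SEEDS = range(1, 6)  # assumed folder names: seed1 ... seed5


def expected_run_dirs(root="output", dataset="ssdg_pacs", nlab=210,
                      trainer="StyleMatch", backbone="resnet18", cfg="v1"):
    """List the 4 x 5 = 20 run folders implied by the tree above."""
    base = Path(root) / dataset / f"nlab_{nlab}" / trainer / backbone / cfg
    return [base / domain / f"seed{seed}"
            for domain in TARGET_DOMAINS for seed in SEEDS]


if __name__ == "__main__":
    unfinished = [d for d in expected_run_dirs() if not d.is_dir()]
    print(f"{len(unfinished)} run folders missing")
```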

To show the results, run

python parse_test_res.py output/ssdg_pacs/nlab_210/StyleMatch/resnet18/v1 --multi-exp

Citation

If you use this code in your research, please cite our paper

@article{zhou2021stylematch,
    title={Semi-Supervised Domain Generalization with Stochastic StyleMatch},
    author={Zhou, Kaiyang and Loy, Chen Change and Liu, Ziwei},
    journal={arXiv preprint arXiv:2106.00592},
    year={2021}
}