Self-supervised learning of optimally robust representations for domain generalization.

Overview

OptDom: Learning Optimal Representations for Domain Generalization

This repository contains the official implementation of Optimal Representations for Covariate Shift. Our paper theoretically characterizes the minimal sufficient representations for optimal domain generalization (DG) under covariate shift and derives practical self-supervised learning (SSL) objectives for learning such representations.

We provide code for reproducing our main results. Contribution highlights:

  • Finetuning pretrained SSL models (CLIP) into DG models with superior robustness (see Minimal Code for Custom Finetuning)
  • A novel contrastive adversarial domain (CAD) bottleneck for learning domain-invariant representations (implemented in domainbed/bottlenecks.py)

Setup

  1. Install PyTorch 1.7.1 and CLIP following the instructions.
  2. Install other packages: pip install -r requirements.txt.

Finetune & Evaluate CLIP on DomainBed

Our paper derives SSL objectives for learning optimally robust representations and gives insights into the superior robustness of CLIP (Sec 4). Here we include the code for finetuning CLIP with our proposed objectives and evaluating on the DomainBed benchmark, which reproduces our experiments in Sec 6.2.

The implementation is in the DomainBed directory, which is heavily based on the DomainBed repo. The CLIP-based models are implemented in domainbed/clip_algorithms.py, and the domain bottlenecks are in domainbed/bottlenecks.py. The training script for finetuning CLIP with bottlenecks is domainbed/scripts/train_clip.py.

Preparation

Move to the DomainBed directory and download the datasets:

python -m domainbed.scripts.download --data_dir ./datasets/

By default, we download the datasets: PACS, VLCS, OfficeHome, TerraIncognita, DomainNet.

Launch a single run

To launch a single run for debugging, run:

bash run_debug.sh

The key arguments include:

  • --dataset: dataset for finetuning and evaluation.
  • --algorithm: algorithms implemented with CLIP, see domainbed/clip_algorithms.py.
  • --test_envs: list of left-out environments for testing, others used for training/finetuning.
  • --hparams: JSON-serialized hyperparameter dict, see domainbed/hparams_registry.py for the list of all hyperparameters.
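
As a rough illustration of how these arguments fit together, the sketch below assembles a command line for the training script and serializes the hyperparameter dict with json.dumps. The helper build_train_command is hypothetical, the --data_dir flag is assumed to mirror the download script, and the dataset, algorithm, and hyperparameter values are only illustrative; run_debug.sh remains the reference for the real invocation.

```python
import json
import shlex


def build_train_command(dataset, algorithm, test_envs, hparams, data_dir="./datasets/"):
    """Assemble a shell command for domainbed/scripts/train_clip.py (hypothetical helper)."""
    args = [
        "python", "-m", "domainbed.scripts.train_clip",
        "--data_dir", data_dir,                       # assumed flag, mirrors the download script
        "--dataset", dataset,
        "--algorithm", algorithm,
        "--test_envs", *[str(env) for env in test_envs],
        "--hparams", json.dumps(hparams),             # JSON-serialized hyperparameter dict
    ]
    # Quote every argument so the JSON string survives the shell intact.
    return " ".join(shlex.quote(arg) for arg in args)


command = build_train_command(
    dataset="PACS",
    algorithm="SupCLIPBottleneckCondCAD",
    test_envs=[0],
    hparams={"clip_model": "ViT-B/32", "batch_size": 32},  # illustrative values
)
print(command)
```

Quoting each argument matters here because the JSON string contains double quotes and spaces that the shell would otherwise split.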

Note that the result of a single run can be very sensitive to the hyperparameters and the random seed. We recommend launching a sweep over hyperparameters and random seeds, as in DomainBed.
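
To make the point about seed sensitivity concrete, here is a minimal sketch of summarizing one configuration over several random seeds as a mean and a standard error. The function summarize_over_seeds is a hypothetical helper, and the accuracies are placeholder values chosen for illustration, not results from the paper or from this code.

```python
import math
import statistics


def summarize_over_seeds(accuracies):
    """Return (mean, standard error) of the accuracies obtained with different random seeds."""
    mean = statistics.mean(accuracies)
    if len(accuracies) < 2:
        return mean, 0.0  # a single run gives no estimate of the spread
    std_err = statistics.stdev(accuracies) / math.sqrt(len(accuracies))
    return mean, std_err


# Placeholder numbers for illustration only.
placeholder_accuracies = [0.80, 0.82, 0.84]
mean, std_err = summarize_over_seeds(placeholder_accuracies)
print(f"{mean:.3f} +/- {std_err:.3f}")
```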

Launch a sweep with tuning

To launch a sweep, run:

bash run_sweep_clip.sh

A sweep over 10 hyperparameter settings and 5 random seeds is launched for each dataset and algorithm. By default, the CLIP-RN50 model is used; you can run other models by changing the clip_model argument, e.g., ViT-B/32 for CLIP-ViT-B/32. To launch a sweep, you also need to select or implement a command launcher in domainbed/command_launchers.py and set the launcher argument accordingly. If you are using Slurm, we already provide a Slurm launcher that you can adapt.
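
The following sketch illustrates the shape of such a sweep and of a custom launcher. It assumes that a launcher is a plain function receiving a list of shell command strings, as in upstream DomainBed; whether this repository keeps exactly that interface should be checked in domainbed/command_launchers.py. The names make_sweep_jobs, dry_run_launcher, serial_launcher, and EXAMPLE_LAUNCHERS are hypothetical.

```python
import itertools
import subprocess


def make_sweep_jobs(datasets, algorithms, n_hparams=10, n_seeds=5):
    """Enumerate (dataset, algorithm, hparams index, seed index) tuples for a sweep."""
    return list(itertools.product(datasets, algorithms, range(n_hparams), range(n_seeds)))


def dry_run_launcher(commands):
    """Print each command instead of executing it (useful for checking a sweep)."""
    for cmd in commands:
        print(f"[dry run] {cmd}")


def serial_launcher(commands):
    """Run the commands one after another on the local machine and return their exit codes."""
    return [subprocess.call(cmd, shell=True) for cmd in commands]


# Hypothetical registry mapping a launcher name to its function.
EXAMPLE_LAUNCHERS = {"dry_run": dry_run_launcher, "serial": serial_launcher}

jobs = make_sweep_jobs(["PACS"], ["SupCLIPBottleneckCondCAD"])
print(len(jobs))  # 10 settings * 5 seeds = 50 jobs, before accounting for the left-out test environments

# Placeholder commands; a real sweep would generate one training command per job.
EXAMPLE_LAUNCHERS["dry_run"]([f"echo job {index}" for index in range(3)])
```

A dry-run launcher of this kind is a cheap way to inspect the generated commands before submitting them to a cluster.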

After the sweep finishes, you can collect the results with the notebook collect_clip_results.ipynb. Note that the results may differ slightly from those in the paper due to code cleanup.

(Optional) Run CAD in DomainBed setup

You can also evaluate our proposed (conditional) CAD bottleneck in the DomainBed setup, where a ResNet-50 is trained end-to-end on source domains and evaluated on a left-out target domain. The implementation is in domainbed/algorithms.py; run it with:

bash run_sweep_e2e_dombed.sh

You can collect the results with the notebook collect_e2e_results.ipynb. As our paper argues, algorithms in this setup lack access to information about the target domain, so we do not expect our bottlenecks or other algorithms to necessarily outperform ERM. Surprisingly, our CAD bottleneck nevertheless leads to consistent improvements.
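
For readers new to this protocol, the sketch below enumerates leave-one-domain-out splits: each domain in turn is held out as the target, and the remaining domains serve as sources. The helper leave_one_domain_out is hypothetical, and the four PACS domain names are used only as an example; the domain ordering inside this repository may differ.

```python
def leave_one_domain_out(domains):
    """Return one split per domain, holding that domain out as the target."""
    splits = []
    for test_index, target in enumerate(domains):
        sources = [name for index, name in enumerate(domains) if index != test_index]
        splits.append({"test_envs": [test_index], "sources": sources, "target": target})
    return splits


pacs_domains = ["art_painting", "cartoon", "photo", "sketch"]
splits = leave_one_domain_out(pacs_domains)
for split in splits:
    print(split["test_envs"], split["sources"], "->", split["target"])
```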

Finetune CLIP on LAION-400M

Coming soon!

Minimal Code for Custom Finetuning

If you want to finetune CLIP on your own dataset with our bottlenecks, here is a minimal code example:

import torch
from torch.utils.data import DataLoader, TensorDataset
import clip
from tqdm import tqdm

from domainbed import hparams_registry
from domainbed import algorithms


# 1. Choose between supervised and contrastive finetuning:
#       - True: use a cross-entropy loss with a supervised dataset
#       - False: use a contrastive loss with a text-image-pair dataset
supervised_finetuning = True

if supervised_finetuning:
    loss_name = "Sup"
    dataset_name = "my supervised dataset"
else:
    loss_name = "Contrast"
    dataset_name = "my text-image pair dataset"


# 2. Choose the bottleneck you want to use; each has different properties
bottleneck_name = "CondCAD"  # Ent, CAD, CondCAD
algorithm_name = loss_name + "CLIPBottleneck" + bottleneck_name


# 3. Set hyperparameters; you can also change the hyperparameter dict and default values
hparams = hparams_registry.default_hparams(algorithm_name, dataset_name)


# 4. Load the pretrained CLIP model
device = "cuda" if torch.cuda.is_available() else "cpu"

pretrained, preprocess = clip.load(hparams['clip_model'], device, jit=False)


# 5. Load your dataset. Each sample should have the form:
#       - (image, label) for supervised finetuning
#       - (image, text) for contrastive finetuning
#    Remember to use the CLIP preprocessing function for image transformation.
#    Your dataset should be a list of sub-datasets from different domains
#       (a singleton list for a single domain).
dataset = load_your_dataset(dataset_name, preprocess)  # user-provided function
num_envs = len(dataset)
num_classes = dataset.num_classes  # dummy for text-image-pair dataset


# 6. Featurize your dataset with the CLIP model

def get_clip_feature(clip_model, x, y, device):
    """Compute CLIP image features; for contrastive finetuning, also encode the texts `y`."""
    with torch.no_grad():
        z = clip_model.encode_image(x.to(device)).float()
        if supervised_finetuning:
            y = y.to(device)  # `y` is a tensor of class labels
        else:
            # `y` is a batch of raw texts: tokenize it, then move the tokens to the model's device
            tokens = clip.tokenize(list(y)).to(device)
            y = clip_model.encode_text(tokens).float()
    return z, y

def clip_featurize_data(clip_model, dataset, device):
    """Featurize one sub-dataset into a TensorDataset of (feature, target) pairs."""
    Z, Y = [], []
    for x, y in tqdm(DataLoader(dataset, batch_size=512, num_workers=4)):
        z, y = get_clip_feature(clip_model, x, y, device)
        Z += [z.cpu()]
        Y += [y.cpu()]
    return TensorDataset(torch.cat(Z), torch.cat(Y))

def clip_precompute_splits(clip_model, splits, device):
    _splits = []
    for ds in splits:
        _splits.append(clip_featurize_data(clip_model, ds, device))
    return _splits


dataset = clip_precompute_splits(pretrained, dataset, device)
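train_loaders = [DataLoader(
    dataset=env,
    batch_size=hparams['batch_size'],
    shuffle=True,  # reshuffle on every pass over the data
    num_workers=4)
    for env in dataset]


def cycle_minibatches(loaders):
    """Yield one minibatch per domain forever, restarting all loaders once the shortest is exhausted."""
    while True:
        yield from zip(*loaders)


# A plain zip(*train_loaders) would be exhausted after a single epoch, so cycle over it instead
train_minibatches_iterator = cycle_minibatches(train_loaders)
steps_per_epoch = int(min([len(env) / hparams['batch_size'] for env in dataset]))
n_steps = hparams['max_step']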


# 7. Initialize the model:
algorithm_class = algorithms.get_algorithm_class(algorithm_name)
algorithm = algorithm_class(pretrained.visual.output_dim, num_classes, num_envs, hparams, pretrained, None)
algorithm.to(device)
algorithm.train()


# 8. Finetune the model:
for step in range(n_steps):
    minibatches_device = [(x.to(device), y.to(device)) for x, y in next(train_minibatches_iterator)]
    algorithm.adjust_lr(step, n_steps, steps_per_epoch)
    step_vals = algorithm.update(minibatches_device, None)
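
Step 5 of the example calls load_your_dataset, which you need to provide yourself. The sketch below shows one possible shape for its return value, assuming each domain is given as a list of (image path, label) pairs: a list of per-domain map-style datasets that also carries a num_classes attribute. All names here (DomainDataset, MultiDomainDataset, build_multi_domain_dataset, load_image) are hypothetical and not part of this repository.

```python
class DomainDataset:
    """Map-style dataset for a single domain; any object with __len__ and __getitem__ works with DataLoader."""

    def __init__(self, samples, transform, load_image):
        self.samples = samples        # list of (image path, label) pairs
        self.transform = transform    # e.g. the CLIP `preprocess` function
        self.load_image = load_image  # e.g. PIL.Image.open

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        path, label = self.samples[index]
        return self.transform(self.load_image(path)), label


class MultiDomainDataset(list):
    """List of per-domain datasets that also exposes `num_classes`."""

    def __init__(self, domains, num_classes):
        super().__init__(domains)
        self.num_classes = num_classes


def build_multi_domain_dataset(samples_per_domain, transform, load_image):
    """Build the list-of-domains structure from a {domain name: [(path, label), ...]} dict."""
    domains = [DomainDataset(samples, transform, load_image)
               for samples in samples_per_domain.values()]
    labels = {label for samples in samples_per_domain.values() for _, label in samples}
    return MultiDomainDataset(domains, num_classes=len(labels))


# Stand-in loader and transform so the sketch runs without image files.
toy_dataset = build_multi_domain_dataset(
    {"photo": [("a.jpg", 0), ("b.jpg", 1)], "sketch": [("c.jpg", 1)]},
    transform=lambda image: image.upper(),
    load_image=lambda path: f"image:{path}",
)
print(len(toy_dataset), toy_dataset.num_classes)
```

In real use you would pass the CLIP preprocess function as transform and an image-opening function as load_image; for contrastive finetuning the second element of each pair would be a text instead of a label.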

Cite

If you find this work relevant to your work, please cite our paper:

@article{ruan2021optdom,
  title={Optimal Representations for Covariate Shift},
  author={Ruan, Yangjun and  Dubois, Yann and Maddison, Chris J},
  journal={arXiv preprint arXiv:2201.00057},
  year={2022},
}

Acknowledgement

Our code is based on the DomainBed repo.
