[ICLR'22] Self-supervised learning optimally robust representations for domain shift.

OptDom: Learning Optimal Representations for Domain Shift

This repository contains the official implementation of Optimal Representations for Covariate Shift. Our paper theoretically characterizes the minimal sufficient representations for optimal domain generalization (DG) under covariate shift and derives practical self-supervised learning (SSL) objectives for learning such representations.

We provide code for reproducing our main results. Highlights of our contributions:

  • Finetuning pretrained SSL models (CLIP) to obtain superior robust DG models [minimal example]
  • A novel contrastive adversarial domain bottleneck for learning domain-invariant representations [implementation]

Setup

  1. Install PyTorch 1.7.1 and CLIP following the instructions.
  2. Install other packages: pip install -r requirements.txt.

Finetune & Evaluate CLIP on DomainBed

Our paper derives SSL objectives for learning optimally robust representations and gives insights into the superior robustness of CLIP (Sec 4). Here we include the code for finetuning CLIP with our proposed objectives and evaluating on the DomainBed benchmark, which reproduces our experiments in Sec 6.2.

The implementation is in the DomainBed directory, which is largely based on the DomainBed repo. The CLIP-based models are implemented in domainbed/clip_algorithms.py, and the domain bottlenecks are in domainbed/bottlenecks.py. The training script for finetuning CLIP with bottlenecks is domainbed/scripts/train_clip.py.

Preparation

Move to the DomainBed directory and download the datasets:

python -m domainbed.scripts.download --data_dir ./datasets/

By default, we download the datasets: PACS, VLCS, OfficeHome, TerraIncognita, DomainNet.

Launch a single run

To launch a single run for debugging, run:

bash run_debug.sh

The key arguments include:

  • --dataset: dataset for finetuning and evaluation.
  • --algorithm: algorithms implemented with CLIP, see domainbed/clip_algorithms.py.
  • --test_envs: list of left-out environments for testing, others used for training/finetuning.
  • --hparams: JSON-serialized hyperparameter dict, see domainbed/hparams_registry.py for the list of all hyperparameters.
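
As a rough illustration of how these arguments fit together, the sketch below assembles the command for one run in Python. The flag names come from the list above. The module path domainbed.scripts.train_clip is inferred from the script location, and the --data_dir flag, the space-separated form of --test_envs, and the batch_size value are assumptions that should be checked against run_debug.sh.

```python
import json
import shlex


def build_train_command(dataset, algorithm, test_envs, hparams, data_dir="./datasets/"):
    """Assemble the argv list for one finetuning run (illustrative helper, not part of the repo)."""
    return [
        "python", "-m", "domainbed.scripts.train_clip",  # assumed module path
        "--data_dir", data_dir,                          # assumed flag, mirrors the download script
        "--dataset", dataset,
        "--algorithm", algorithm,
        "--test_envs", *[str(env) for env in test_envs],  # assumed to be space-separated indices
        "--hparams", json.dumps(hparams),                 # JSON-serialized hyperparameter dict
    ]


cmd = build_train_command(
    dataset="PACS",
    algorithm="SupCLIPBottleneckCondCAD",
    test_envs=[0],
    hparams={"clip_model": "ViT-B/32", "batch_size": 32},  # example values only
)
print(shlex.join(cmd))  # shell-quoted, so the JSON string survives copy and paste
```

Serializing the dict with json.dumps and quoting the command with shlex.join avoids hand-escaping the braces and quotes in the --hparams value.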

Note that the result of a single run can be very sensitive to hyperparameters and the random seed. We recommend launching a sweep over hyperparameters and random seeds, as in DomainBed.

Launch a sweep with tuning

To launch a sweep, run:

bash run_sweep_clip.sh

A sweep over 10 hyperparameter configurations and 5 random seeds is launched for each dataset and algorithm. By default, the CLIP-RN50 model is used; to use another model, change the clip_model argument (e.g., ViT-B/32 for CLIP-ViT-B/32). To launch a sweep, you also need to select or implement a command launcher in domainbed/command_launchers.py and pass it via the launcher argument. If you are using Slurm, you can adapt the Slurm launcher we already provide.
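
To get a feel for the size of a sweep before submitting it, the following sketch enumerates the runs for one dataset. It assumes, as in the original DomainBed sweep, that each training domain is held out once as the test environment; the function and the job fields are illustrative and do not mirror the sweep script.

```python
from itertools import product


def enumerate_sweep_jobs(env_names, algorithms, n_hparams=10, n_trials=5):
    """List one job per (algorithm, held-out domain, hyperparameter draw, random seed)."""
    jobs = []
    for algorithm, test_env, hparams_seed, trial_seed in product(
            algorithms, range(len(env_names)), range(n_hparams), range(n_trials)):
        jobs.append({
            "algorithm": algorithm,
            "test_env": env_names[test_env],  # domain left out for testing
            "hparams_seed": hparams_seed,     # index of the hyperparameter draw
            "trial_seed": trial_seed,         # index of the random seed
        })
    return jobs


pacs_envs = ["art_painting", "cartoon", "photo", "sketch"]
jobs = enumerate_sweep_jobs(pacs_envs, ["SupCLIPBottleneckCondCAD"])
print(len(jobs))  # 4 held-out domains x 10 draws x 5 seeds = 200
```

Under these assumptions a single algorithm on PACS already amounts to 200 runs, which is why a cluster launcher is worth setting up.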

After the sweep is finished, you can collect the results with the notebooks collect_clip_results.ipynb and collect_clip_results_per_setup.ipynb. Note that the results may differ slightly from those in the paper due to code cleaning.

(Optional) Run CAD in DomainBed setup

You can also evaluate our proposed (conditional) CAD bottleneck in the DomainBed setup, where a ResNet-50 is trained end-to-end on source domains and evaluated on a left-out target domain. The implementation is in domainbed/algorithms.py, and you can run it with:

bash run_sweep_e2e_dombed.sh

You can collect the results with the notebook collect_e2e_results.ipynb. As our paper argues, algorithms in this setup have no access to information about the target domain, so we do not expect our bottlenecks or other algorithms to necessarily outperform ERM. Surprisingly, our CAD bottleneck nevertheless leads to consistent improvement.

Finetune CLIP on LAION-400M

Our paper also shows how to learn generic (both task- and domain-agnostic) representations that are robust to distribution shift, using a text-image contrastive loss and an entropy bottleneck that requires no prior knowledge about domains. Here we include the code for finetuning CLIP with our proposed objectives on the LAION-400M dataset and evaluating on ImageNet-Testbed and the DomainBed benchmark, which reproduces our experiments in Sec 6.3.

Preparation

  1. Move to the LAION directory.
  2. Download the dataset with 1TB of CLIP preprocessed embeddings following the instructions.
  3. Prepare the dataset (you may need to link the directories to a location with enough disk space):
python scripts/prepare_laion_dataset.py --data_download_dir ./data/download --data_processed_dir ./data/processed
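
One possible way to handle the disk-space note above is to turn the two data directories into symbolic links before running the preparation script. The helper below is a sketch; /path/to/large/disk is a placeholder and the helper name is illustrative.

```python
from pathlib import Path


def link_to_large_disk(local_dir, storage_dir):
    """Make `local_dir` a symlink to `storage_dir`, creating the target if needed."""
    local_dir, storage_dir = Path(local_dir), Path(storage_dir)
    storage_dir.mkdir(parents=True, exist_ok=True)
    if local_dir.is_symlink() or local_dir.exists():
        return local_dir  # leave an existing directory or link untouched
    local_dir.parent.mkdir(parents=True, exist_ok=True)
    local_dir.symlink_to(storage_dir.resolve(), target_is_directory=True)
    return local_dir


# Example usage from the LAION directory (placeholder storage path):
# link_to_large_disk("./data/download", "/path/to/large/disk/laion/download")
# link_to_large_disk("./data/processed", "/path/to/large/disk/laion/processed")
```

The script can then be called with the same --data_download_dir and --data_processed_dir values as above, while the files land on the larger disk.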

Run finetuning

To finetune CLIP with and without our entropy bottleneck, run:

bash finetune_clip_laion.sh

By default, the CLIP-ViT-B/32 model is finetuned on LAION-400M for 1 epoch.

Run evaluation

Evaluate on ImageNet-Testbed

To evaluate finetuned models on ImageNet-Testbed, where a linear classifier is trained on ImageNet and evaluated on natural distribution shift datasets, run:

bash evaluate_imagenet_testbed.sh

You need to set your ImageNet dataset path with the imagenet_data_dir argument. Note that the finetuned models with trained ImageNet classifiers are already registered to ImageNet-Testbed in imagenet-testbed/src/models/clip_vit.py. After the evaluation is finished, you can collect results with the notebook collect_results_imagenet_testbed.ipynb.

Evaluate on DomainBed

To evaluate finetuned models on DomainBed as before, run:

bash evaluate_dombed.sh

You need to set your DomainBed dataset path with the dombed_data_dir argument. After the evaluation is finished, you can collect results with the notebook collect_results_domainbed.ipynb.

Minimal Code for Custom Finetuning

If you want to finetune CLIP on your own dataset with our bottlenecks, here is a minimal code example:

import torch
from torch.utils.data import DataLoader, TensorDataset
import clip
from tqdm import tqdm

from domainbed import hparams_registry
from domainbed import algorithms


# 1. Determine whether you do supervised or contrastive finetuning:
#       - True: use a cross-entropy loss with a supervised dataset
#       - False: use a contrastive loss with a text-image-pair dataset
supervised_finetuning = True

if supervised_finetuning:
    loss_name = "Sup"
    dataset_name = "my supervised dataset"
else:
    loss_name = "Contrast"
    dataset_name = "my text-image pair dataset"


# 2. Determine the bottleneck you want to use with different properties
bottleneck_name = "CondCAD"  # Ent, CAD, CondCAD
algorithm_name = loss_name + "CLIPBottleneck" + bottleneck_name


# 3. Set hyperparameters, you can also change the hyperparameter dict and default values
hparams = hparams_registry.default_hparams(algorithm_name, dataset_name)


# 4. Load pretrained CLIP models
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

pretrained, preprocess = clip.load(hparams['clip_model'], device, jit=False)


# 5. Load your dataset, your dataset should have the form:
#       - (image, label) for supervised finetuning
#       - (image, text) for contrastive finetuning
#    Remember to use the CLIP preprocessing function for image transformation,
#       and your dataset should be a list of sub-datasets from different domains (singleton for a single domain)
dataset = load_your_dataset(dataset_name, preprocess)
num_envs = len(dataset)
num_classes = dataset.num_classes  # dummy for text-image-pair dataset


# 6. Featurize your dataset with CLIP models

def get_clip_feature(clip_model, x, y):
    """Compute CLIP features"""
    with torch.no_grad():
        z = clip_model.encode_image(x).float()
        if not supervised_finetuning:  # `y` is a batch of texts that should be tokenized
            # Tokens must live on the same device as the model before encoding
            y = clip_model.encode_text(clip.tokenize(y).to(x.device)).float()
    return z, y

def clip_featurize_data(clip_model, dataset, device):
    """Featurize a dataset"""
    Z, Y = [], []
    for x, y in tqdm(DataLoader(dataset, batch_size=512, num_workers=4)):
        if supervised_finetuning:  # labels are tensors; raw texts have no `.to()` and are moved after tokenization
            y = y.to(device)
        z, y = get_clip_feature(clip_model, x.to(device), y)
        Z += [z.cpu()]
        Y += [y.cpu()]
    return TensorDataset(torch.cat(Z), torch.cat(Y))

def clip_precompute_splits(clip_model, splits, device):
    _splits = []
    for ds in splits:
        _splits.append(clip_featurize_data(clip_model, ds, device))
    return _splits


dataset = clip_precompute_splits(pretrained, dataset, device)
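train_loaders = [DataLoader(
    dataset=env,
    batch_size=hparams['batch_size'],
    num_workers=4)
    for env in dataset]


def cycle_minibatches(loaders):
    """Yield one minibatch per domain forever, restarting once the shortest loader is exhausted"""
    while True:
        yield from zip(*loaders)


train_minibatches_iterator = cycle_minibatches(train_loaders)  # a plain zip would stop after one epoch
steps_per_epoch = int(min([len(env) / hparams['batch_size'] for env in dataset]))
n_steps = hparams['max_step']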


# 7. Initialize the model:
algorithm_class = algorithms.get_algorithm_class(algorithm_name)
algorithm = algorithm_class(pretrained.visual.output_dim, num_classes, num_envs, hparams, pretrained, None)
algorithm.to(device)
algorithm.train()


# 8. Finetune the model:
for step in range(n_steps):
    minibatches_device = [(x.to(device), y.to(device)) for x, y in next(train_minibatches_iterator)]
    algorithm.adjust_lr(step, n_steps, steps_per_epoch)
    step_vals = algorithm.update(minibatches_device, None)
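
The example above leaves load_your_dataset undefined. As one possible shape for its return value, the sketch below builds a list-like container with one map-style sub-dataset per domain and a num_classes attribute, matching the description in step 5. All class and function names here are hypothetical. The sketch is written without a torch dependency on the assumption that DataLoader accepts any object implementing __getitem__ and __len__.

```python
from typing import Any, Callable, Dict, Iterator, List, Sequence, Tuple


class DomainDataset:
    """Map-style dataset for a single domain; items are (transformed image, target)."""

    def __init__(self, samples: Sequence[Tuple[str, Any]],
                 transform: Callable[[Any], Any],
                 loader: Callable[[str], Any]) -> None:
        self.samples = list(samples)  # (image path, label or caption) pairs
        self.transform = transform    # e.g. the `preprocess` returned by clip.load
        self.loader = loader          # e.g. PIL.Image.open

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        path, target = self.samples[index]
        return self.transform(self.loader(path)), target


class MultiDomainDataset:
    """List-like container with one DomainDataset per domain plus `num_classes`."""

    def __init__(self, domains: List[DomainDataset], num_classes: int) -> None:
        self.domains = domains
        self.num_classes = num_classes

    def __len__(self) -> int:
        return len(self.domains)  # number of domains, i.e. `num_envs`

    def __getitem__(self, index: int) -> DomainDataset:
        return self.domains[index]

    def __iter__(self) -> Iterator[DomainDataset]:
        return iter(self.domains)


def build_multi_domain_dataset(samples_per_domain: Dict[str, Sequence[Tuple[str, Any]]],
                               transform: Callable[[Any], Any],
                               loader: Callable[[str], Any],
                               num_classes: int) -> MultiDomainDataset:
    """Wrap {domain name: [(path, target), ...]} into per-domain datasets, sorted by domain name."""
    domains = [DomainDataset(samples, transform, loader)
               for _, samples in sorted(samples_per_domain.items())]
    return MultiDomainDataset(domains, num_classes)
```

A load_your_dataset(dataset_name, preprocess) function could then gather the (path, target) pairs for each domain and call build_multi_domain_dataset with transform=preprocess and loader=PIL.Image.open. For a single-domain dataset, pass a dict with one entry.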

Cite

If you find this work relevant to your work, please cite our paper:

@inproceedings{ruan2022optdom,
    title={Optimal Representations for Covariate Shift},
    author={Yangjun Ruan and Yann Dubois and Chris J. Maddison},
    booktitle={International Conference on Learning Representations},
    year={2022},
    url={https://openreview.net/forum?id=Rf58LPCwJj0}
}

Acknowledgement

Our code is based on: