Robust fine-tuning of zero-shot models

Related tags

Deep Learningwise-ft
Overview

Robust fine-tuning of zero-shot models

This repository contains code for the paper Robust fine-tuning of zero-shot models by Mitchell Wortsman*, Gabriel Ilharco*, Jong Wook Kim, Mike Li, Simon Kornblith, Rebecca Roelofs, Raphael Gontijo-Lopes, Hannaneh Hajishirzi, Ali Farhadi, Hongseok Namkoong, Ludwig Schmidt.

Abstract

Large pre-trained models such as CLIP offer consistent accuracy across a range of data distributions when performing zero-shot inference (i.e., without fine-tuning on a specific dataset). Although existing fine-tuning approaches substantially improve accuracy in-distribution, they also reduce out-of-distribution robustness. We address this tension by introducing a simple and effective method for improving robustness: ensembling the weights of the zero-shot and fine-tuned models. Compared to standard fine-tuning, the resulting weight-space ensembles provide large accuracy improvements out-of-distribution, while matching or improving in-distribution accuracy. On ImageNet and five derived distribution shifts, weight-space ensembles improve out-of-distribution accuracy by 2 to 10 percentage points while increasing in-distribution accuracy by nearly 1 percentage point relative to standard fine-tuning. These improvements come at no additional computational cost during fine-tuning or inference.

Summary figure

figure1

Compared to standard fine-tuning, weight-space ensembles for fine-tuning (WiSE-FT) improve out-of-distribution (OOD) accuracy without decreasing in-distribution (ID) performance. Top left: Zero-shot CLIP models exhibit high effective robustness and moderate in-distribution accuracy, while standard fine-tuning (end-to-end or with a linear classifier) attains higher ID accuracy and less effective robustness. Top right: Our method linearly interpolates between the zero-shot and fine-tuned models with a mixing coefficient alpha in [0,1]. Bottom: On five distribution shifts derived from ImageNet (ImageNetV2, ImageNet-R, ImageNet Sketch, ObjectNet, and ImageNet-A), WiSE-FT improves average OOD accuracy by 8.7 percentage points (pp) when fine-tuning end-to-end (+2.1 pp when fine-tuning a linear classifier) while maintaining ID accuracy.

Code

Overview

WiSE-FT can be implemented in a few lines of code in addition to standard fine-tuning, as shown below. See src/wise_ft.py for more details.

# Load models
zeroshot = ImageClassifier.load(zeroshot_checkpoint)
finetuned = ImageClassifier.load(finetuned_checkpoint)
theta_0 = zeroshot.state_dict()
theta_1 = finetuned.state_dict()

# make sure checkpoints are compatible
assert set(theta_0.keys()) == set(theta_1.keys())

# interpolate between checkpoints with mixing coefficient alpha
theta = {
    key: (1-alpha) * theta_0[key] + alpha * theta_1[key]
    for key in theta_0.keys()
}

# update the model acccording to the new weights
finetuned.load_state_dict(theta)

# evaluate
evaluate(finetuned, args)

Install dependencies

conda env create
conda activate wiseft

Add directory to PYTHONPATH:

cd wise-ft
export PYTHONPATH="$PYTHONPATH:$PWD"

Download data

When necessary, please refer to datasets.md for instructions on how to download datasets.

Run WiSE-FT

Sample command when zeroshot and fine-tuned models are available:

python src/wise_ft.py   \
    --eval-datasets=ImageNet,ImageNetV2,ImageNetR,ImageNetA,ImageNetSketch  \
    --load=models/zeroshot.pt,models/finetuned.pt  \
    --results-db=results.jsonl  \
    --save=models/wiseft  \
    --data-location=~/data \
    --alpha 0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0

Sample command for running WiSE-FT from scratch using ViT-B/32:

python src/wise_ft.py   \
    --train-dataset=ImageNet  \
    --epochs=10  \
    --lr=0.00003  \
    --batch-size=512  \
    --cache-dir=cache  \
    --model=ViT-B/32  \
    --eval-datasets=ImageNet,ImageNetV2,ImageNetR,ImageNetA,ImageNetSketch  \
    --template=openai_imagenet_template  \
    --results-db=results.jsonl  \
    --save=models/wiseft/ViTB32  \
    --data-location=~/data \
    --alpha 0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0

Note: the flag --freeze-encoder controls whether only a linear classifier is fine-tuned, or if all weights are fine-tuned (end-to-end).

Plotting results

Sample command for generating a scatter plot:

python src/scatter_plot.py  \
    --eval-datasets=ImageNetV2,ImageNetR,ImageNetA,ImageNetSketch  \
    --results-db=results.jsonl  \
    --save plots

We show samples of expected behavior below when running the commands above using ViT-B/16 (models can be downloaded here):

ImageNet-Sketch         ImageNet-A

ImageNet-R         ImageNetV2

ObjectNet

Citing

If you found this repository useful, please consider citing:

@article{wortsman2021robust,
  title={Robust fine-tuning of zero-shot models},
  author={Wortsman, Mitchell and Ilharco, Gabriel and Kim, Jong Wook and Li, Mike and Kornblith, Simon and Roelofs, Rebecca and Gontijo-Lopes, Raphael and Hajishirzi, Hannaneh and Farhadi, Ali and Namkoong, Hongseok and Schmidt, Ludwig},
  journal={arXiv preprint arXiv:2109.01903},
  note={\url{https://arxiv.org/abs/2109.01903}},
  year={2021}
}
Vehicle detection using machine learning and computer vision techniques for Udacity's Self-Driving Car Engineer Nanodegree.

Vehicle Detection Video demo Overview Vehicle detection using these machine learning and computer vision techniques. Linear SVM HOG(Histogram of Orien

hata 1.1k Dec 18, 2022
GLANet - The code for Global and Local Alignment Networks for Unpaired Image-to-Image Translation arxiv

GLANet The code for Global and Local Alignment Networks for Unpaired Image-to-Image Translation arxiv Framework: visualization results: Getting Starte

stanley 29 Dec 14, 2022
A library for preparing, training, and evaluating scalable deep learning hybrid recommender systems using PyTorch.

collie_recs Collie is a library for preparing, training, and evaluating implicit deep learning hybrid recommender systems, named after the Border Coll

ShopRunner 97 Jan 03, 2023
One implementation of the paper "DMRST: A Joint Framework for Document-Level Multilingual RST Discourse Segmentation and Parsing".

Introduction One implementation of the paper "DMRST: A Joint Framework for Document-Level Multilingual RST Discourse Segmentation and Parsing". Users

seq-to-mind 18 Dec 11, 2022
Hyperbolic Image Segmentation, CVPR 2022

Hyperbolic Image Segmentation, CVPR 2022 This is the implementation of paper Hyperbolic Image Segmentation (CVPR 2022). Repository structure assets :

Mina Ghadimi Atigh 46 Dec 29, 2022
Out-of-boundary View Synthesis towards Full-frame Video Stabilization

Out-of-boundary View Synthesis towards Full-frame Video Stabilization Introduction | Update | Results Demo | Introduction This repository contains the

25 Oct 10, 2022
Example scripts for the detection of lanes using the ultra fast lane detection model in ONNX.

Example scripts for the detection of lanes using the ultra fast lane detection model in ONNX.

Ibai Gorordo 35 Sep 07, 2022
Machine Learning Toolkit for Kubernetes

Kubeflow the cloud-native platform for machine learning operations - pipelines, training and deployment. Documentation Please refer to the official do

Kubeflow 12.1k Jan 03, 2023
Extremely simple and fast extreme multi-class and multi-label classifiers.

napkinXC napkinXC is an extremely simple and fast library for extreme multi-class and multi-label classification, that focus of implementing various m

Marek Wydmuch 43 Nov 14, 2022
Implementation of Advantage-Weighted Regression: Simple and Scalable Off-Policy Reinforcement Learning

advantage-weighted-regression Implementation of Advantage-Weighted Regression: Simple and Scalable Off-Policy Reinforcement Learning, by Peng et al. (

Omar D. Domingues 1 Dec 02, 2021
Based on Yolo's low-power, ultra-lightweight universal target detection algorithm, the parameter is only 250k, and the speed of the smart phone mobile terminal can reach ~300fps+

Based on Yolo's low-power, ultra-lightweight universal target detection algorithm, the parameter is only 250k, and the speed of the smart phone mobile terminal can reach ~300fps+

567 Dec 26, 2022
GBK-GNN: Gated Bi-Kernel Graph Neural Networks for Modeling Both Homophily and Heterophily

GBK-GNN: Gated Bi-Kernel Graph Neural Networks for Modeling Both Homophily and Heterophily Abstract Graph Neural Networks (GNNs) are widely used on a

10 Dec 20, 2022
CVPR2021 Content-Aware GAN Compression

Content-Aware GAN Compression [ArXiv] Paper accepted to CVPR2021. @inproceedings{liu2021content, title = {Content-Aware GAN Compression}, auth

52 Nov 06, 2022
PyTorch implementation of DreamerV2 model-based RL algorithm

PyDreamer Reimplementation of DreamerV2 model-based RL algorithm in PyTorch. The official DreamerV2 implementation can be found here. Features ... Run

118 Dec 15, 2022
Extracting knowledge graphs from language models as a diagnostic benchmark of model performance.

Interpreting Language Models Through Knowledge Graph Extraction Idea: How do we interpret what a language model learns at various stages of training?

EPFL Machine Learning and Optimization Laboratory 9 Oct 25, 2022
Image Segmentation using U-Net, U-Net with skip connections and M-Net architectures

Brain-Image-Segmentation Segmentation of brain tissues in MRI image has a number of applications in diagnosis, surgical planning, and treatment of bra

Angad Bajwa 8 Oct 27, 2022
CapsuleVOS: Semi-Supervised Video Object Segmentation Using Capsule Routing

CapsuleVOS This is the code for the ICCV 2019 paper CapsuleVOS: Semi-Supervised Video Object Segmentation Using Capsule Routing. Arxiv Link: https://a

53 Oct 27, 2022
Code for "Learning Structural Edits via Incremental Tree Transformations" (ICLR'21)

Learning Structural Edits via Incremental Tree Transformations Code for "Learning Structural Edits via Incremental Tree Transformations" (ICLR'21) 1.

NeuLab 40 Dec 23, 2022
A Simple Example for Imitation Learning with Dataset Aggregation (DAGGER) on Torcs Env

Imitation Learning with Dataset Aggregation (DAGGER) on Torcs Env This repository implements a simple algorithm for imitation learning: DAGGER. In thi

Hao 66 Nov 23, 2022