Learning to Prompt for Vision-Language Models.

CoOp

Paper: Learning to Prompt for Vision-Language Models

Authors: Kaiyang Zhou, Jingkang Yang, Chen Change Loy, Ziwei Liu

CoOp (Context Optimization) is a differentiable approach based on continuous prompt learning that makes it easier to deploy pre-trained vision-language models (like CLIP) on downstream datasets.
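As a rough illustration of what a continuous prompt looks like, the sketch below lays out the token order for a single class, following the paper's description of M learnable context vectors combined with a class token placed either at the end or in the middle. The function name prompt_layout and the string placeholders are hypothetical; in the real model each [V] slot is a learnable vector rather than a string.

```python
def prompt_layout(n_ctx, class_name, position="end"):
    """Return the token order of a CoOp-style prompt as a list of placeholders.

    n_ctx is the number of context tokens (M); position is "end" or "middle".
    """
    ctx = [f"[V{i}]" for i in range(1, n_ctx + 1)]
    cls = f"[{class_name}]"
    if position == "end":
        # All context tokens first, class token last.
        return ctx + [cls]
    if position == "middle":
        # Class token sits between the two halves of the context.
        half = n_ctx // 2
        return ctx[:half] + [cls] + ctx[half:]
    raise ValueError(f"unknown position: {position!r}")


print(" ".join(prompt_layout(4, "airplane", "end")))
print(" ".join(prompt_layout(4, "airplane", "middle")))
```

With M=4 this prints the context tokens followed by the class token for "end", and the class token between the second and third context tokens for "middle".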

Updates

  • 15.10.2021: We find that the best_val model and the last_step model achieve similar performance, so we set TEST.FINAL_MODEL = "last_step" for all datasets to save training time. Why we used best_val: the (tiny) validation set was designed for the linear probe approach, which requires extensive tuning for its hyperparameters, so we used the best_val model for CoOp as well for fair comparison (in this way, both approaches have access to the validation set).

  • 09.10.2021: Important changes are made to Dassl's transforms.py. Please pull the latest commits from https://github.com/KaiyangZhou/Dassl.pytorch and this repo to make sure the code works properly. In particular, 1) center_crop now becomes a default transform in testing (applied after resizing the smaller edge to a certain size to keep the image aspect ratio), and 2) for training, Resize(cfg.INPUT.SIZE) is deactivated when random_crop or random_resized_crop is used. Please read this issue on how these changes might affect the performance.

  • 18.09.2021: We have fixed an error in Dassl which could cause a training data loader to have zero length (so no training would be performed) when the dataset size is smaller than the batch size (due to drop_last=True). Please pull the latest commit for Dassl (>= 8eecc3c). This error led to lower results for CoOp in EuroSAT's 1- and 2-shot settings (all other results are correct). We will update the paper on arXiv to fix this error.
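The zero-length loader described in the 18.09.2021 update comes down to simple arithmetic, sketched below in plain Python. The helper num_batches is hypothetical and only mirrors how a batch count behaves with and without drop_last; the batch size of 32 is an assumed value for illustration, while EuroSAT's 10 classes determine the few-shot training set sizes.

```python
def num_batches(dataset_size, batch_size, drop_last):
    """Number of batches a loader yields per epoch."""
    full, remainder = divmod(dataset_size, batch_size)
    if drop_last or remainder == 0:
        # An incomplete final batch is discarded (or does not exist).
        return full
    return full + 1


NUM_CLASSES = 10   # EuroSAT
BATCH_SIZE = 32    # assumed value, for illustration only

for shots in (1, 2, 4):
    size = NUM_CLASSES * shots
    dropped = num_batches(size, BATCH_SIZE, drop_last=True)
    kept = num_batches(size, BATCH_SIZE, drop_last=False)
    print(f"{shots}-shot: {size} images, drop_last=True -> {dropped} batches, "
          f"drop_last=False -> {kept} batches")
```

Under these assumptions the 1- and 2-shot settings yield zero batches with drop_last=True, so no training step would run, while the 4-shot setting yields one.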

How to Install

This code is built on top of the awesome toolbox Dassl.pytorch, so you need to install the dassl environment first. Simply follow the installation instructions in the Dassl.pytorch repository to install dassl as well as PyTorch. After that, run pip install -r requirements.txt under CoOp/ to install a few more packages required by CLIP (do this while the dassl environment is activated). Then, you are ready to go.

Follow DATASETS.md to install the datasets.

How to Run

We provide the running scripts in scripts/. Make sure you change the path in DATA and run the commands under CoOp/scripts/.

Few-Shot Learning

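All you need is CoOp/scripts/main.sh, which takes six input arguments. The first two are described below. Judging from the example commands and the output folder names further down, the remaining four appear to be, in order, the class token position (end or middle), the number of context tokens M, the number of shots, and whether class-specific context (CSC) is used (True or False).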

DATASET takes as input a dataset name, like imagenet or caltech101. The valid names are the files' names in CoOp/configs/datasets/.

CFG specifies which config file to use, such as rn50, rn101 or vit_b32 (see CoOp/configs/trainers/CoOp/). Note that for ImageNet, we use CoOp/configs/trainers/CoOp/*_ep50.yaml for all settings (please follow the implementation details shown in the paper).

Below we provide examples of how to run CoOp on Caltech101.

CLIP + CoOp (M=16, end):

  • 1 shot: bash main.sh caltech101 rn50_ep50 end 16 1 False
  • 2 shots: bash main.sh caltech101 rn50_ep100 end 16 2 False
  • 4 shots: bash main.sh caltech101 rn50_ep100 end 16 4 False
  • 8 shots: bash main.sh caltech101 rn50 end 16 8 False
  • 16 shots: bash main.sh caltech101 rn50 end 16 16 False

CLIP + CoOp (M=16, mid):

  • 1 shot: bash main.sh caltech101 rn50_ep50 middle 16 1 False
  • 2 shots: bash main.sh caltech101 rn50_ep100 middle 16 2 False
  • 4 shots: bash main.sh caltech101 rn50_ep100 middle 16 4 False
  • 8 shots: bash main.sh caltech101 rn50 middle 16 8 False
  • 16 shots: bash main.sh caltech101 rn50 middle 16 16 False

CLIP + CoOp (M=16, end, CSC):

  • 1 shot: bash main.sh caltech101 rn50_ep50 end 16 1 True
  • 2 shots: bash main.sh caltech101 rn50_ep100 end 16 2 True
  • 4 shots: bash main.sh caltech101 rn50_ep100 end 16 4 True
  • 8 shots: bash main.sh caltech101 rn50 end 16 8 True
  • 16 shots: bash main.sh caltech101 rn50 end 16 16 True

CLIP + CoOp (M=16, mid, CSC):

  • 1 shot: bash main.sh caltech101 rn50_ep50 middle 16 1 True
  • 2 shots: bash main.sh caltech101 rn50_ep100 middle 16 2 True
  • 4 shots: bash main.sh caltech101 rn50_ep100 middle 16 4 True
  • 8 shots: bash main.sh caltech101 rn50 middle 16 8 True
  • 16 shots: bash main.sh caltech101 rn50 middle 16 16 True
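The twenty commands above follow a regular pattern, so they can be generated rather than typed by hand. The sketch below is a hypothetical helper (coop_command and CFG_BY_SHOTS are not part of the repository) that only builds the command strings; the shot-to-config mapping is copied from the Caltech101 examples and should not be assumed to hold for other datasets such as ImageNet.

```python
# Config used for each shot count in the Caltech101 examples.
CFG_BY_SHOTS = {
    1: "rn50_ep50",
    2: "rn50_ep100",
    4: "rn50_ep100",
    8: "rn50",
    16: "rn50",
}


def coop_command(dataset, shots, position="end", n_ctx=16, csc=False):
    """Build one main.sh command line in the same argument order as the examples."""
    cfg = CFG_BY_SHOTS[shots]
    return f"bash main.sh {dataset} {cfg} {position} {n_ctx} {shots} {csc}"


for position in ("end", "middle"):
    for csc in (False, True):
        for shots in sorted(CFG_BY_SHOTS):
            print(coop_command("caltech101", shots, position, csc=csc))
```

Each printed line can be run from CoOp/scripts/, for instance by passing the string to subprocess.run with shell=True.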

After the experiments are finished, you can use parse_test_res.py to calculate the average results instead of manually looking into the log files. Say the structure of output/ is

output
|–– caltech101/
|   |–– CoOp/
|   |   |–– rn50_16shots/
|   |   |   |–– nctx16_cscFalse_ctpend/
|   |   |   |   |–– seed1/
|   |   |   |   |–– seed2/
|   |   |   |   |–– seed3/
|   |   |–– rn50_8shots/
|   |   |   |–– nctx16_cscFalse_ctpend/
|   |   |   |   |–– seed1/
|   |   |   |   |–– seed2/
|   |   |   |   |–– seed3/

To calculate the average results for the folder rn50_16shots/nctx16_cscFalse_ctpend/, you can run

python parse_test_res.py output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend

Then, you will see something like this in your terminal

Parsing files in output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend
file: output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend/seed1/log.txt. accuracy: 91.81%. error: 8.19%.
file: output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend/seed2/log.txt. accuracy: 92.01%. error: 7.99%.
file: output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend/seed3/log.txt. accuracy: 92.17%. error: 7.83%.
===
Summary of directory: output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend
* accuracy: 92.00% +- 0.15%
* error: 8.00% +- 0.15%
===
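The summary lines can be checked by hand from the three per-seed accuracies. The sketch below is not the code of parse_test_res.py; it assumes the reported spread is a population standard deviation, an assumption that happens to reproduce the 0.15% shown above (a sample standard deviation would give 0.18% instead).

```python
import statistics

# Per-seed accuracies copied from the example output.
accuracies = [91.81, 92.01, 92.17]

mean = statistics.mean(accuracies)
std = statistics.pstdev(accuracies)  # population standard deviation (assumed)

print(f"* accuracy: {mean:.2f}% +- {std:.2f}%")
print(f"* error: {100 - mean:.2f}% +- {std:.2f}%")
```

The error spread equals the accuracy spread because each error is simply 100 minus the corresponding accuracy.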

How to initialize the context tokens with pre-trained word vectors? Specify the words for the parameter TRAINER.COOP.CTX_INIT in your config file. In our paper, we use configs/trainers/rn50_ctxv1.yaml (give this file to --config-file, see scripts/main.sh), which uses "a photo of a" as the initialization words.

How to visualize nearest words for the learned context tokens? All you need is interpret_prompt.py. Say the learned tokens are saved in a/b/c/prompt_learner/model.pth.tar and you would like to see the top-3 nearest words for each token. In this case, run python interpret_prompt.py a/b/c/prompt_learner/model.pth.tar 3
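To make the idea of "nearest words" concrete, here is a toy sketch that ranks vocabulary words by their distance to a learned token vector. It does not reproduce interpret_prompt.py: the function nearest_words, the 2-D embeddings and the use of Euclidean distance are all assumptions made for illustration, whereas the real script works on CLIP's token embedding table.

```python
import math


def nearest_words(token_vec, vocab, k):
    """Return the k words whose embeddings are closest to token_vec."""
    ranked = sorted(vocab, key=lambda word: math.dist(token_vec, vocab[word]))
    return [(word, math.dist(token_vec, vocab[word])) for word in ranked[:k]]


# Toy 2-D word embeddings, made up for this example.
toy_vocab = {
    "photo": (1.0, 0.0),
    "image": (0.9, 0.1),
    "dog": (0.0, 1.0),
    "the": (0.5, 0.5),
}
learned_token = (0.95, 0.0)  # stands in for one learned context vector

for word, distance in nearest_words(learned_token, toy_vocab, 3):
    print(f"{word}: {distance:.4f}")
```

Learned context vectors live in a continuous space, so the nearest word is only an approximate reading of what a token encodes.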

Robustness to Distribution Shift

To reproduce the robustness experiments, you can simply load the models learned on ImageNet and evaluate them on the following datasets: imagenetv2, imagenet-sketch, imagenet-a and imagenet-r.

The command is provided in CoOp/scripts/eval.sh. The key arguments are --model-dir, --load-epoch and --eval-only. --model-dir indicates the directory where the models are saved (i.e. the entire folder containing log.txt, the tensorboard file and prompt_learner/). --load-epoch tells the code to load the model saved at a specific epoch, like --load-epoch 50 for ImageNet (see the source code for more details).

For example, to evaluate CLIP + CoOp (M=16, end) on ImageNetV2, you can do

# Don't need to use rn50_ep50 here as no training is performed
bash eval.sh imagenetv2 rn50

The default setting is SHOTS=16. Feel free to modify the script.

Again, you can use parse_test_res.py to automate the calculation of average performance. This time you should append --test-log, e.g., python parse_test_res.py directory --test-log.

Zero-Shot CLIP

See CoOp/scripts/zeroshot.sh.

Linear Probe CLIP

Please move to lpclip/.

How to Cite CoOp

If you use this code in your research, please kindly cite the following paper

@article{zhou2021coop,
    title={Learning to Prompt for Vision-Language Models},
    author={Zhou, Kaiyang and Yang, Jingkang and Loy, Chen Change and Liu, Ziwei},
    journal={arXiv preprint arXiv:2109.01134},
    year={2021}
}