Learning to Prompt for Vision-Language Models.

CoOp

Paper: Learning to Prompt for Vision-Language Models

Authors: Kaiyang Zhou, Jingkang Yang, Chen Change Loy, Ziwei Liu

CoOp (Context Optimization) is a differentiable approach that focuses on continuous prompt learning to facilitate the deployment of pre-trained vision-language models (like CLIP) on downstream datasets.
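To make "continuous prompt" concrete, the sketch below lays out CoOp-style prompts symbolically. M learnable context tokens [V]1 ... [V]M are combined with the class name either before it ("end") or split around it ("middle", [V]1 ... [V]M/2 [CLASS] [V]M/2+1 ... [V]M), and are either shared by all classes or learned separately per class (class-specific context, CSC). It only manipulates strings: the helper name and placeholder tokens are hypothetical, and it ignores the start/end tokens and punctuation the real tokenizer adds. In the actual model each placeholder is a learnable embedding vector fed to CLIP's text encoder.

```python
from typing import List


def build_prompt(class_name: str, n_ctx: int, position: str = "end",
                 class_specific: bool = False) -> List[str]:
    """Return the token layout of one CoOp-style prompt (hypothetical helper).

    Context placeholders stand for learnable vectors: "V1".."VM" when shared
    by all classes, "V{i}[class]" when each class has its own context (CSC).
    """
    if class_specific:
        ctx = [f"V{i}[{class_name}]" for i in range(1, n_ctx + 1)]
    else:
        ctx = [f"V{i}" for i in range(1, n_ctx + 1)]

    if position == "end":
        # [V]1 ... [V]M [CLASS]
        return ctx + [class_name]
    if position == "middle":
        # [V]1 ... [V]M/2 [CLASS] [V]M/2+1 ... [V]M
        half = n_ctx // 2
        return ctx[:half] + [class_name] + ctx[half:]
    raise ValueError(f"unknown class token position: {position!r}")


for position in ("end", "middle"):
    print(" ".join(build_prompt("airplane", 4, position)))
# V1 V2 V3 V4 airplane
# V1 V2 airplane V3 V4
```

With class_specific=True, every class gets its own copy of the context (e.g. "V1[dog] V2[dog] dog"), which adds parameters in proportion to the number of classes.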

Updates

  • 15.10.2021: We find that the best_val model and the last_step model achieve similar performance, so we set TEST.FINAL_MODEL = "last_step" for all datasets to save training time. Why we used best_val: the (tiny) validation set was designed for the linear probe approach, which requires extensive tuning for its hyperparameters, so we used the best_val model for CoOp as well for fair comparison (in this way, both approaches have access to the validation set).

  • 09.10.2021: Important changes have been made to Dassl's transforms.py. Please pull the latest commits from https://github.com/KaiyangZhou/Dassl.pytorch and this repo to make sure the code works properly. In particular, 1) center_crop is now a default transform in testing (applied after resizing the smaller edge to a certain size to keep the image aspect ratio), and 2) for training, Resize(cfg.INPUT.SIZE) is deactivated when random_crop or random_resized_crop is used. Please read this issue on how these changes might affect the performance.

  • 18.09.2021: We have fixed an error in Dassl that could cause a training data loader to have zero length (so no training would be performed) when the dataset size is smaller than the batch size (due to drop_last=True). Please pull the latest commit for Dassl (>= 8eecc3c). This error led to lower results for CoOp in EuroSAT's 1- and 2-shot settings (all other results are correct). We will update the paper on arXiv to fix this error.
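The zero-length loader in the last update is plain arithmetic: with drop_last=True, a loader yields floor(N / B) batches per epoch, so a training set with fewer images N than the batch size B yields none. EuroSAT has 10 classes, so its K-shot training sets contain 10 and 20 images in the 1- and 2-shot settings. The sketch below reproduces the count with the standard library; the batch size of 32 is an assumption for illustration, and num_batches is a hypothetical helper that mirrors the usual PyTorch DataLoader length rule.

```python
import math


def num_batches(num_samples: int, batch_size: int, drop_last: bool) -> int:
    """Batches per epoch: drop_last=True discards the incomplete final batch."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


BATCH_SIZE = 32       # assumption: illustrative training batch size
EUROSAT_CLASSES = 10  # K labeled images per class -> 10 * K training images

for shots in (1, 2, 4):
    n = EUROSAT_CLASSES * shots
    print(shots, n, num_batches(n, BATCH_SIZE, drop_last=True))
# 1 10 0
# 2 20 0
# 4 40 1
```

Under this assumed batch size, only the 1- and 2-shot settings end up with zero batches, which is consistent with those being the only affected EuroSAT results.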

How to Install

This code is built on top of the awesome toolbox Dassl.pytorch, so you need to install the dassl environment first. Simply follow the installation instructions in Dassl.pytorch to install dassl as well as PyTorch. After that, with the dassl environment activated, run pip install -r requirements.txt under CoOp/ to install a few more packages required by CLIP. Then you are ready to go.

Follow DATASETS.md to install the datasets.

How to Run

We provide the running scripts in scripts/. Before running them, change the DATA path in the scripts to the directory where your datasets are stored, and run the commands from CoOp/scripts/.

Few-Shot Learning

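All you need is CoOp/scripts/main.sh, which takes six input arguments, in this order:

  • DATASET takes as input a dataset name, like imagenet or caltech101. The valid names are the file names in CoOp/configs/datasets/.
  • CFG specifies which config file to use, such as rn50, rn101 or vit_b32 (see CoOp/configs/trainers/CoOp/). Note that for ImageNet, we use CoOp/configs/trainers/CoOp/*_ep50.yaml for all settings (please follow the implementation details shown in the paper).
  • CTP is the class token position, either end or middle.
  • NCTX is the number of context tokens (M), e.g. 16.
  • SHOTS is the number of labeled training examples per class, e.g. 1, 2, 4, 8 or 16.
  • CSC indicates whether class-specific context is used (True) or a single context is shared by all classes (False).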

Below we provide examples of how to run CoOp on Caltech101.

CLIP + CoOp (M=16, end):

  • 1 shot: bash main.sh caltech101 rn50_ep50 end 16 1 False
  • 2 shots: bash main.sh caltech101 rn50_ep100 end 16 2 False
  • 4 shots: bash main.sh caltech101 rn50_ep100 end 16 4 False
  • 8 shots: bash main.sh caltech101 rn50 end 16 8 False
  • 16 shots: bash main.sh caltech101 rn50 end 16 16 False

CLIP + CoOp (M=16, mid):

  • 1 shot: bash main.sh caltech101 rn50_ep50 middle 16 1 False
  • 2 shots: bash main.sh caltech101 rn50_ep100 middle 16 2 False
  • 4 shots: bash main.sh caltech101 rn50_ep100 middle 16 4 False
  • 8 shots: bash main.sh caltech101 rn50 middle 16 8 False
  • 16 shots: bash main.sh caltech101 rn50 middle 16 16 False

CLIP + CoOp (M=16, end, CSC):

  • 1 shot: bash main.sh caltech101 rn50_ep50 end 16 1 True
  • 2 shots: bash main.sh caltech101 rn50_ep100 end 16 2 True
  • 4 shots: bash main.sh caltech101 rn50_ep100 end 16 4 True
  • 8 shots: bash main.sh caltech101 rn50 end 16 8 True
  • 16 shots: bash main.sh caltech101 rn50 end 16 16 True

CLIP + CoOp (M=16, mid, CSC):

  • 1 shot: bash main.sh caltech101 rn50_ep50 middle 16 1 True
  • 2 shots: bash main.sh caltech101 rn50_ep100 middle 16 2 True
  • 4 shots: bash main.sh caltech101 rn50_ep100 middle 16 4 True
  • 8 shots: bash main.sh caltech101 rn50 middle 16 8 True
  • 16 shots: bash main.sh caltech101 rn50 middle 16 16 True
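The 20 commands above follow a fixed pattern. The positional arguments are the dataset, the config, the class token position, the number of context tokens, the number of shots and the CSC flag, and the config depends only on the number of shots (rn50_ep50 for 1 shot, rn50_ep100 for 2 and 4 shots, rn50 for 8 and 16 shots). The hypothetical helper below regenerates the list, for example to write it into a batch script. The shots-to-config mapping is read off these Caltech101/RN50 examples and is an assumption for other backbones; ImageNet uses the *_ep50 configs for all settings instead.

```python
from typing import List

# Config per shot count, as in the Caltech101 / RN50 examples above.
CFG_BY_SHOTS = {1: "rn50_ep50", 2: "rn50_ep100", 4: "rn50_ep100",
                8: "rn50", 16: "rn50"}

# (class token position, class-specific context) for the four variants,
# in the same order as the lists above.
VARIANTS = [("end", False), ("middle", False), ("end", True), ("middle", True)]


def main_sh_commands(dataset: str = "caltech101", nctx: int = 16) -> List[str]:
    """Build main.sh commands for every variant and shot setting."""
    commands = []
    for ctp, csc in VARIANTS:
        for shots, cfg in CFG_BY_SHOTS.items():
            # Order: dataset, config, class token position, #context tokens,
            # #shots, CSC flag (bools format as "True"/"False").
            commands.append(f"bash main.sh {dataset} {cfg} {ctp} {nctx} {shots} {csc}")
    return commands
```

Printing "\n".join(main_sh_commands()) reproduces the four lists above in the same order.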

After the experiments are finished, you can use parse_test_res.py to calculate the average results instead of manually looking into the log files. Suppose output/ has the following structure:

output
|–– caltech101/
|   |–– CoOp/
|   |   |–– rn50_16shots/
|   |   |   |–– nctx16_cscFalse_ctpend/
|   |   |   |   |–– seed1/
|   |   |   |   |–– seed2/
|   |   |   |   |–– seed3/
|   |   |–– rn50_8shots/
|   |   |   |–– nctx16_cscFalse_ctpend/
|   |   |   |   |–– seed1/
|   |   |   |   |–– seed2/
|   |   |   |   |–– seed3/

To calculate the average results for the folder rn50_16shots/nctx16_cscFalse_ctpend/, you can run:

python parse_test_res.py output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend

Then you will see something like this in your terminal:

Parsing files in output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend
file: output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend/seed1/log.txt. accuracy: 91.81%. error: 8.19%.
file: output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend/seed2/log.txt. accuracy: 92.01%. error: 7.99%.
file: output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend/seed3/log.txt. accuracy: 92.17%. error: 7.83%.
===
Summary of directory: output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend
* accuracy: 92.00% +- 0.15%
* error: 8.00% +- 0.15%
===
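The summary is the mean and standard deviation of the per-seed results. As a minimal sketch (not the code of parse_test_res.py), the snippet below extracts accuracies from lines shaped like the per-seed output above and averages them. It uses the population standard deviation, which is what reproduces the 0.15% shown for these three seeds; the sample standard deviation would give 0.18%. The regex and function name are our own.

```python
import re
import statistics
from typing import Iterable, Tuple

# Matches e.g. "... accuracy: 91.81%. error: 8.19%." and captures 91.81.
ACCURACY_RE = re.compile(r"accuracy: (\d+(?:\.\d+)?)%")


def summarize_accuracy(lines: Iterable[str]) -> Tuple[float, float]:
    """Return (mean, population std) of all accuracies found in `lines`."""
    values = []
    for line in lines:
        match = ACCURACY_RE.search(line)
        if match:
            values.append(float(match.group(1)))
    if not values:
        raise ValueError("no accuracy values found")
    return statistics.mean(values), statistics.pstdev(values)


seed_lines = [
    "seed1/log.txt. accuracy: 91.81%. error: 8.19%.",
    "seed2/log.txt. accuracy: 92.01%. error: 7.99%.",
    "seed3/log.txt. accuracy: 92.17%. error: 7.83%.",
]
mean, std = summarize_accuracy(seed_lines)
print(f"* accuracy: {mean:.2f}% +- {std:.2f}%")  # * accuracy: 92.00% +- 0.15%
```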

How to initialize the context tokens with pre-trained word vectors? Specify the words for the parameter TRAINER.COOP.CTX_INIT in your config file. In our paper, we use configs/trainers/rn50_ctxv1.yaml (pass this file to --config-file; see scripts/main.sh), which uses "a photo of a" as the initialization words.
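Initializing from words means the context tokens start from the embeddings of a phrase instead of random values, so the initial prompt for each class is simply the phrase followed by the class name. The sketch below shows only that bookkeeping. It assumes, as we understand the CoOp code to do, that the number of context tokens is then taken from the phrase, and that each word maps to one token (true for "a photo of a" under CLIP's tokenizer). init_context is a hypothetical helper.

```python
from typing import Tuple


def init_context(ctx_init: str, class_name: str) -> Tuple[int, str]:
    """Hypothetical helper: context length and initial prompt from CTX_INIT.

    Assumption: one token per word. In the real model, the word embeddings of
    these tokens become the starting values of the learnable context vectors.
    """
    words = ctx_init.split()
    n_ctx = len(words)
    initial_prompt = " ".join(words + [class_name]) + "."
    return n_ctx, initial_prompt


n_ctx, prompt = init_context("a photo of a", "dog")
print(n_ctx, prompt)  # 4 a photo of a dog.
```

Before any training, this initial prompt coincides with the hand-crafted "a photo of a [CLASS]." template used for zero-shot CLIP; training then moves the context vectors away from those word embeddings.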

How to visualize nearest words for the learned context tokens? All you need is interpret_prompt.py. Say the learned tokens are saved in a/b/c/prompt_learner/model.pth.tar and you would like to see the top-3 nearest words for each token. In this case, run python interpret_prompt.py a/b/c/prompt_learner/model.pth.tar 3
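Conceptually, interpreting a learned token means ranking vocabulary words by how close their word embeddings are to the learned context vector. The stdlib-only sketch below does this with Euclidean distance on a made-up four-word vocabulary. The vocabulary, vectors and function name are hypothetical, and the choice of Euclidean distance is an assumption, not a description of interpret_prompt.py.

```python
import math
from typing import Dict, List, Sequence


def nearest_words(ctx_vectors: Sequence[Sequence[float]],
                  vocab: Dict[str, Sequence[float]], k: int) -> List[List[str]]:
    """For each learned context vector, return the k closest vocabulary words."""
    result = []
    for ctx in ctx_vectors:
        # Rank all words by Euclidean distance to this context vector.
        ranked = sorted(vocab, key=lambda word: math.dist(ctx, vocab[word]))
        result.append(ranked[:k])
    return result


# Toy 2-D "embeddings"; real CLIP token embeddings are much higher-dimensional.
toy_vocab = {"photo": (1.0, 0.0), "image": (0.9, 0.1),
             "dog": (0.0, 1.0), "cat": (0.1, 0.9)}
learned_ctx = [(0.95, 0.02), (0.02, 0.97)]
print(nearest_words(learned_ctx, toy_vocab, k=2))
# [['photo', 'image'], ['dog', 'cat']]
```

Because learned context vectors need not lie near any real word, the nearest words are only a rough hint at what a token encodes.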

Robustness to Distribution Shift

To reproduce the robustness experiments, you can simply load the models learned on ImageNet and evaluate them on the following datasets: imagenetv2, imagenet-sketch, imagenet-a and imagenet-r.

The command is provided in CoOp/scripts/eval.sh. The key arguments are --model-dir, --load-epoch and --eval-only. --model-dir indicates the directory where the models are saved (i.e. the entire folder containing log.txt, the tensorboard file and prompt_learner/). --load-epoch tells the code to load the model saved at a specific epoch, like --load-epoch 50 for ImageNet (see the source code for more details).

For example, to evaluate CLIP + CoOp (M=16, end) on ImageNetV2, you can run:

# No need to use rn50_ep50 here, as no training is performed
bash eval.sh imagenetv2 rn50

The default setting is SHOTS=16. Feel free to modify the script.

Again, you can use parse_test_res.py to automate the calculation of average performance. This time you should append --test-log, e.g., python parse_test_res.py directory --test-log.
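Evaluating on all four target datasets and then averaging the results can be scripted. The sketch below only builds the command strings; to execute them, pass each one to subprocess.run(cmd, shell=True, check=True) from CoOp/scripts/. The dataset names are copied from the list above and should be checked against the file names in CoOp/configs/datasets/. The result directory passed to parse_command is a placeholder, since its location depends on how eval.sh names its output folders.

```python
from typing import List

# Target datasets for the distribution-shift experiments, as listed above.
# Assumption: these match the config names in CoOp/configs/datasets/.
TARGET_DATASETS = ["imagenetv2", "imagenet-sketch", "imagenet-a", "imagenet-r"]


def eval_commands(cfg: str = "rn50") -> List[str]:
    """One eval.sh call per target dataset (evaluation only, no training)."""
    return [f"bash eval.sh {dataset} {cfg}" for dataset in TARGET_DATASETS]


def parse_command(result_dir: str) -> str:
    """parse_test_res.py call for evaluation-only runs, which needs --test-log."""
    return f"python parse_test_res.py {result_dir} --test-log"


for cmd in eval_commands():
    print(cmd)
```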

Zero-Shot CLIP

See CoOp/scripts/zeroshot.sh.

Linear Probe CLIP

Please move to lpclip/.

How to Cite CoOp

If you use this code in your research, please cite the following paper:

@article{zhou2021coop,
    title={Learning to Prompt for Vision-Language Models},
    author={Zhou, Kaiyang and Yang, Jingkang and Loy, Chen Change and Liu, Ziwei},
    journal={arXiv preprint arXiv:2109.01134},
    year={2021}
}