PyTorch implementation for the paper Visual Representation Learning with Self-Supervised Attention for Low-Label High-Data Regime

Overview

Visual Representation Learning with Self-Supervised Attention for Low-Label High-Data Regime

Created by Prarthana Bhattacharyya.

Disclaimer: This is not an official product and is meant to be a proof-of-concept and for academic/educational use only.

This repository contains the PyTorch implementation for the paper Visual Representation Learning with Self-Supervised Attention for Low-Label High-Data Regime, to be presented at ICASSP-2022.

Self-supervision has shown outstanding results for natural language processing, and more recently, for image recognition. Simultaneously, vision transformers and its variants have emerged as a promising and scalable alternative to convolutions on various computer vision tasks. In this paper, we are the first to question if self-supervised vision transformers (SSL-ViTs) can be adapted to two important computer vision tasks in the low-label, high-data regime: few-shot image classification and zero-shot image retrieval. The motivation is to reduce the number of manual annotations required to train a visual embedder, and to produce generalizable, semantically meaningful and robust embeddings.


Results

  • SSL-ViT + few-shot image classification:
  • Qualitative analysis for base-classes chosen by supervised CNN and SSL-ViT for few-shot distribution calibration:
  • SSL-ViT + zero-shot image retrieval:

Pretraining Self-Supervised ViT

  • Run DINO with ViT-small network on a single node with 4 GPUs for 100 epochs with the following command.
cd dino/
python -m torch.distributed.launch --nproc_per_node=4 main_dino.py --arch vit_small --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
  • For mini-ImageNet pretraining, we use the classes listed in: ssl-vit-fewshot/data/ImageNetSSLTrainingSplit_mini.txt For tiered-ImageNet pretraining, we use the classes listed in: ssl-vit-fewshot/data/ImageNetSSLTrainingSplit_tiered.txt
  • For CUB-200, Cars-196 and SOP, we use the pretrained model from:
import torch
vits16 = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')

Visual Representation Learning with Self-Supervised ViT for Low-Label High-Data Regime

Dataset Preparation

Please follow the instruction in FRN for few-shot image classification and RevisitDML for zero-shot image retrieval to download the datasets and put the corresponding datasets in ssl-vit-fewshot/data and DIML/data folder.

Training and Evaluation for few-shot image classification

  • The first step is to extract features for base and novel classes using the pretrained SSL-ViT.
  • get_dino_miniimagenet_feats.ipynb extracts SSL-ViT features for the base and novel classes.
  • Change the hyper-parameter data_path to use CUB or tiered-ImageNet.
  • The SSL-ViT checkpoints for the various datasets are provided below (Note: this has only been trained without labels). We also provide the extracted features which need to be stored in ssl-vit-fewshot/dino_features_data/.
arch dataset download extracted-train extracted-test
ViT-S/16 mini-ImageNet mini_imagenet_checkpoint.pth train.p test.p
ViT-S/16 tiered-ImageNet tiered_imagenet_checkpoint.pth train.p test.p
ViT-S/16 CUB cub_checkpoint.pth train.p test.p
  • For n-way-k-shot evaluation, we provide miniimagenet_evaluate_dinoDC.ipynb.

Training and Evaluation for zero-shot image retrieval

  • To train the baseline CNN models, run the scripts in DIML/scripts/baselines. The checkpoints are saved in Training_Results folder. For example:
cd DIML/
CUDA_VISIBLE_DEVICES=0 ./script/baselines/cub_runs.sh
  • To train the supervised ViT and self-supervised ViT:
cp -r ssl-vit-retrieval/architectures/* DIML/ssl-vit-retrieval/architectures/
CUDA_VISIBLE_DEVICES=0 ./script/baselines/cub_runs.sh --arch vits
CUDA_VISIBLE_DEVICES=0 ./script/baselines/cub_runs.sh --arch dino
  • To test the models, first edit the checkpoint paths in test_diml.py, then run
CUDA_VISIBLE_DEVICES=0 ./scripts/diml/test_diml.sh cub200
dataset Loss SSL-ViT-download
CUB Margin cub_ssl-vit-margin.pth
CUB Proxy-NCA cub_ssl-vit-proxynca.pth
CUB Multi-Similarity cub_ssl-vit-ms.pth
Cars-196 Margin cars_ssl-vit-margin.pth
Cars-196 Proxy-NCA cars_ssl-vit-proxynca.pth
Cars-196 Multi-Similarity cars_ssl-vit-ms.pth

Acknowledgement

The code is based on:

Owner
Prarthana Bhattacharyya
Ph.D. Candidate @WISELab-UWaterloo
Prarthana Bhattacharyya
Created as part of CS50 AI's coursework. This AI makes use of knowledge entailment to calculate the best probabilities to win Minesweeper.

Minesweeper-AI Created as part of CS50 AI's coursework. This AI makes use of knowledge entailment to calculate the best probabilities to win Minesweep

Beckham 0 Jul 20, 2022
Pretrained language model and its related optimization techniques developed by Huawei Noah's Ark Lab.

Pretrained Language Model This repository provides the latest pretrained language models and its related optimization techniques developed by Huawei N

HUAWEI Noah's Ark Lab 2.6k Jan 01, 2023
SeqTR: A Simple yet Universal Network for Visual Grounding

SeqTR This is the official implementation of SeqTR: A Simple yet Universal Network for Visual Grounding, which simplifies and unifies the modelling fo

seanZhuh 76 Dec 24, 2022
Implementation of Wasserstein adversarial attacks.

Stronger and Faster Wasserstein Adversarial Attacks Code for Stronger and Faster Wasserstein Adversarial Attacks, appeared in ICML 2020. This reposito

21 Oct 06, 2022
Unofficial implementation of HiFi-GAN+ from the paper "Bandwidth Extension is All You Need" by Su, et al.

HiFi-GAN+ This project is an unoffical implementation of the HiFi-GAN+ model for audio bandwidth extension, from the paper Bandwidth Extension is All

Brent M. Spell 134 Dec 30, 2022
Code for Recurrent Mask Refinement for Few-Shot Medical Image Segmentation (ICCV 2021).

Recurrent Mask Refinement for Few-Shot Medical Image Segmentation Steps Install any missing packages using pip or conda Preprocess each dataset using

XIE LAB @ UCI 39 Dec 08, 2022
Code for Robust Contrastive Learning against Noisy Views

Robust Contrastive Learning against Noisy Views This repository provides a PyTorch implementation of the Robust InfoNCE loss proposed in paper Robust

Ching-Yao Chuang 53 Jan 08, 2023
Minimal fastai code needed for working with pytorch

fastai_minima A mimal version of fastai with the barebones needed to work with Pytorch #all_slow Install pip install fastai_minima How to use This lib

Zachary Mueller 14 Oct 21, 2022
Agent-based model simulator for air quality and pandemic risk assessment in architectural spaces

Agent-based model simulation for air quality and pandemic risk assessment in architectural spaces. User Guide archABM is a fast and open source agent-

Vicomtech 10 Dec 05, 2022
Code of the paper "Shaping Visual Representations with Attributes for Few-Shot Learning (ASL)".

Shaping Visual Representations with Attributes for Few-Shot Learning This code implements the Shaping Visual Representations with Attributes for Few-S

chx_nju 9 Sep 01, 2022
Meshed-Memory Transformer for Image Captioning. CVPR 2020

M²: Meshed-Memory Transformer This repository contains the reference code for the paper Meshed-Memory Transformer for Image Captioning (CVPR 2020). Pl

AImageLab 422 Dec 28, 2022
Rainbow is all you need! A step-by-step tutorial from DQN to Rainbow

Do you want a RL agent nicely moving on Atari? Rainbow is all you need! This is a step-by-step tutorial from DQN to Rainbow. Every chapter contains bo

Jinwoo Park (Curt) 1.4k Dec 29, 2022
Code release for "Making a Bird AI Expert Work for You and Me".

Making-a-Bird-AI-Expert-Work-for-You-and-Me Code release for "Making a Bird AI Expert Work for You and Me". arxiv (Coming soon...) Changelog 2021/12/6

PRIS-CV: Computer Vision Group 11 Dec 11, 2022
Deep Hedging Demo - An Example of Using Machine Learning for Derivative Pricing.

Deep Hedging Demo Pricing Derivatives using Machine Learning 1) Jupyter version: Run ./colab/deep_hedging_colab.ipynb on Colab. 2) Gui version: Run py

Yu Man Tam 102 Jan 06, 2023
Semi-Supervised Learning, Object Detection, ICCV2021

End-to-End Semi-Supervised Object Detection with Soft Teacher By Mengde Xu*, Zheng Zhang*, Han Hu, Jianfeng Wang, Lijuan Wang, Fangyun Wei, Xiang Bai,

Microsoft 789 Dec 27, 2022
The codebase for our paper "Generative Occupancy Fields for 3D Surface-Aware Image Synthesis" (NeurIPS 2021)

Generative Occupancy Fields for 3D Surface-Aware Image Synthesis (NeurIPS 2021) Project Page | Paper Xudong Xu, Xingang Pan, Dahua Lin and Bo Dai GOF

xuxudong 97 Nov 10, 2022
High-performance moving least squares material point method (MLS-MPM) solver.

High-Performance MLS-MPM Solver with Cutting and Coupling (CPIC) (MIT License) A Moving Least Squares Material Point Method with Displacement Disconti

Yuanming Hu 2.2k Dec 31, 2022
Official implementation of Monocular Quasi-Dense 3D Object Tracking

Monocular Quasi-Dense 3D Object Tracking Monocular Quasi-Dense 3D Object Tracking (QD-3DT) is an online framework detects and tracks objects in 3D usi

Visual Intelligence and Systems Group 441 Dec 20, 2022
Official code release for "Learned Spatial Representations for Few-shot Talking-Head Synthesis" ICCV 2021

Official code release for "Learned Spatial Representations for Few-shot Talking-Head Synthesis" ICCV 2021

Moustafa Meshry 16 Oct 05, 2022
Monocular 3D pose estimation. OpenVINO. CPU inference or iGPU (OpenCL) inference.

human-pose-estimation-3d-python-cpp RealSenseD435 (RGB) 480x640 + CPU Corei9 45 FPS (Depth is not used) 1. Run 1-1. RealSenseD435 (RGB) 480x640 + CPU

Katsuya Hyodo 8 Oct 03, 2022