PyTorch implementation for the paper Visual Representation Learning with Self-Supervised Attention for Low-Label High-Data Regime

Overview


Created by Prarthana Bhattacharyya.

Disclaimer: This is not an official product and is meant to be a proof-of-concept and for academic/educational use only.

This repository contains the PyTorch implementation for the paper Visual Representation Learning with Self-Supervised Attention for Low-Label High-Data Regime, to be presented at ICASSP-2022.

Self-supervision has shown outstanding results for natural language processing and, more recently, for image recognition. Simultaneously, vision transformers and their variants have emerged as a promising and scalable alternative to convolutions on various computer vision tasks. In this paper, we are the first to ask whether self-supervised vision transformers (SSL-ViTs) can be adapted to two important computer vision tasks in the low-label, high-data regime: few-shot image classification and zero-shot image retrieval. The motivation is to reduce the number of manual annotations required to train a visual embedder, and to produce generalizable, semantically meaningful and robust embeddings.


Results

  • SSL-ViT + few-shot image classification
  • Qualitative analysis of the base classes chosen by a supervised CNN and by SSL-ViT for few-shot distribution calibration
  • SSL-ViT + zero-shot image retrieval

Pretraining Self-Supervised ViT

  • Run DINO with a ViT-small network on a single node with 4 GPUs for 100 epochs using the following command:
cd dino/
python -m torch.distributed.launch --nproc_per_node=4 main_dino.py --arch vit_small --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
  • For mini-ImageNet pretraining, we use the classes listed in ssl-vit-fewshot/data/ImageNetSSLTrainingSplit_mini.txt. For tiered-ImageNet pretraining, we use the classes listed in ssl-vit-fewshot/data/ImageNetSSLTrainingSplit_tiered.txt.
  • For CUB-200, Cars-196 and SOP, we use the publicly released DINO ViT-S/16 model, loaded through torch.hub:
import torch
vits16 = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')
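The class split files above restrict DINO pretraining to a subset of ImageNet classes. The following is a minimal sketch of one way to build such a subset, not a script from this repository. It assumes each split file holds one ImageNet synset ID (such as n01532829) per non-empty line and that the ImageNet train folder has one sub-directory per synset. The helper names read_split and link_split_subset are hypothetical.

```python
import os
from pathlib import Path


def read_split(split_file):
    """Return the class IDs in a split file (assumed format: one ID per non-empty line)."""
    with open(split_file) as f:
        return [line.strip() for line in f if line.strip()]


def link_split_subset(imagenet_train, split_file, out_dir):
    """Symlink the class folders listed in split_file from imagenet_train into out_dir.

    Returns the class IDs that were not found in imagenet_train.
    """
    src_root = Path(imagenet_train)
    dst_root = Path(out_dir)
    dst_root.mkdir(parents=True, exist_ok=True)
    missing = []
    for wnid in read_split(split_file):
        src = src_root / wnid
        if not src.is_dir():
            missing.append(wnid)
            continue
        dst = dst_root / wnid
        # Symlinks avoid duplicating the images on disk; skip links that already exist.
        if not (dst.exists() or dst.is_symlink()):
            os.symlink(src.resolve(), dst, target_is_directory=True)
    return missing
```

The resulting directory can then be passed as --data_path to main_dino.py, which loads it with torchvision's ImageFolder (one sub-folder per class). Symlinking keeps a single copy of ImageNet on disk, and the returned list of missing IDs gives a quick check that the split file and dataset version match.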


Dataset Preparation

Please follow the instructions in FRN (few-shot image classification) and RevisitDML (zero-shot image retrieval) to download the datasets, and place them in the ssl-vit-fewshot/data and DIML/data folders, respectively.

Training and Evaluation for few-shot image classification

  • The first step is to extract features for base and novel classes using the pretrained SSL-ViT.
  • get_dino_miniimagenet_feats.ipynb extracts SSL-ViT features for the base and novel classes.
  • To use CUB or tiered-ImageNet instead, change the data_path hyper-parameter in the notebook.
  • The SSL-ViT checkpoints for each dataset are listed below (note: these models were trained without labels). We also provide the extracted features, which should be stored in ssl-vit-fewshot/dino_features_data/.
arch      dataset          checkpoint                      extracted-train  extracted-test
ViT-S/16  mini-ImageNet    mini_imagenet_checkpoint.pth    train.p          test.p
ViT-S/16  tiered-ImageNet  tiered_imagenet_checkpoint.pth  train.p          test.p
ViT-S/16  CUB              cub_checkpoint.pth              train.p          test.p
  • For N-way K-shot evaluation, we provide miniimagenet_evaluate_dinoDC.ipynb.
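The Results section and the notebook name indicate that few-shot classification applies distribution calibration (DC) to SSL-ViT features. The NumPy sketch below runs one N-way K-shot episode in the style of DC. It is not the code in miniimagenet_evaluate_dinoDC.ipynb, and it makes several assumptions:

- train.p and test.p unpickle to a dict mapping class label to a list of feature vectors.
- Features are non-negative, which Tukey's power transform with λ = 0.5 requires. Set lam=1 to disable the transform for features with negative entries.
- A small multinomial logistic regression written in NumPy stands in for an off-the-shelf classifier.

All function names are hypothetical. The defaults (k_base=2, alpha=0.21, 750 sampled features) are illustrative and should be tuned per dataset.

```python
import pickle

import numpy as np


def load_features(path):
    """Load a pickle assumed to map class label -> list of feature vectors."""
    with open(path, "rb") as f:
        data = pickle.load(f)
    return {label: np.asarray(feats, dtype=np.float64) for label, feats in data.items()}


def tukey(x, lam=0.5):
    """Tukey's ladder-of-powers transform; assumes x >= 0 (use lam=1 to disable)."""
    return np.log(x) if lam == 0 else np.power(x, lam)


def base_statistics(base):
    """Per-class mean and covariance of the base-class features."""
    means = np.stack([f.mean(axis=0) for f in base.values()])
    covs = np.stack([np.cov(f, rowvar=False) for f in base.values()])
    return means, covs


def calibrate(x, base_means, base_covs, k_base=2, alpha=0.21):
    """Calibrated Gaussian for one (transformed) support feature x."""
    # Borrow statistics from the k_base base classes whose means are closest to x.
    nearest = np.argsort(np.linalg.norm(base_means - x, axis=1))[:k_base]
    mean = (base_means[nearest].sum(axis=0) + x) / (k_base + 1)
    cov = base_covs[nearest].mean(axis=0) + alpha
    return mean, cov


def train_softmax(x, y, n_classes, lr=0.1, epochs=200, l2=1e-3):
    """Multinomial logistic regression fitted by full-batch gradient descent."""
    w = np.zeros((x.shape[1], n_classes))
    b = np.zeros(n_classes)
    onehot = np.eye(n_classes)[y]
    for _ in range(epochs):
        logits = x @ w + b
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability
        p = np.exp(logits)
        p /= p.sum(axis=1, keepdims=True)
        grad = (p - onehot) / len(x)
        w -= lr * (x.T @ grad + l2 * w)
        b -= lr * grad.sum(axis=0)
    return w, b


def dc_episode(base_means, base_covs, novel, n_way=5, k_shot=1, n_query=15,
               n_samples=750, lam=0.5, k_base=2, alpha=0.21, seed=None):
    """Run one N-way K-shot episode with distribution calibration; return query accuracy."""
    rng = np.random.default_rng(seed)
    labels = list(novel)
    chosen = [labels[i] for i in rng.choice(len(labels), size=n_way, replace=False)]
    support_x, support_y, query_x, query_y = [], [], [], []
    for y, c in enumerate(chosen):
        idx = rng.permutation(len(novel[c]))[: k_shot + n_query]
        feats = tukey(novel[c][idx], lam)
        support_x.append(feats[:k_shot])
        support_y += [y] * k_shot
        query_x.append(feats[k_shot:])
        query_y += [y] * (len(idx) - k_shot)
    support_x = np.concatenate(support_x)
    query_x = np.concatenate(query_x)
    # Augment every support feature with samples drawn from its calibrated Gaussian.
    per_shot = n_samples // k_shot
    train_x, train_y = [support_x], [np.asarray(support_y)]
    for x, y in zip(support_x, support_y):
        mean, cov = calibrate(x, base_means, base_covs, k_base, alpha)
        train_x.append(rng.multivariate_normal(mean, cov, size=per_shot))
        train_y.append(np.full(per_shot, y))
    w, b = train_softmax(np.concatenate(train_x), np.concatenate(train_y), n_way)
    pred = np.argmax(query_x @ w + b, axis=1)
    return float(np.mean(pred == np.asarray(query_y)))
```

In this sketch, base_statistics runs once on the features from train.p. dc_episode is then repeated over many random episodes drawn from the test.p features, and the accuracies are averaged. The calibrated Gaussians let the classifier see a spread of plausible novel-class features even when only one labelled example per class is available.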

Training and Evaluation for zero-shot image retrieval

  • To train the baseline CNN models, run the scripts in DIML/scripts/baselines. The checkpoints are saved in the Training_Results folder. For example:
cd DIML/
CUDA_VISIBLE_DEVICES=0 ./scripts/baselines/cub_runs.sh
  • To train the supervised ViT and the self-supervised ViT:
cp -r ssl-vit-retrieval/architectures/* DIML/ssl-vit-retrieval/architectures/
CUDA_VISIBLE_DEVICES=0 ./scripts/baselines/cub_runs.sh --arch vits
CUDA_VISIBLE_DEVICES=0 ./scripts/baselines/cub_runs.sh --arch dino
  • To test the models, first edit the checkpoint paths in test_diml.py, then run:
CUDA_VISIBLE_DEVICES=0 ./scripts/diml/test_diml.sh cub200
dataset   loss              SSL-ViT checkpoint
CUB       Margin            cub_ssl-vit-margin.pth
CUB       Proxy-NCA         cub_ssl-vit-proxynca.pth
CUB       Multi-Similarity  cub_ssl-vit-ms.pth
Cars-196  Margin            cars_ssl-vit-margin.pth
Cars-196  Proxy-NCA         cars_ssl-vit-proxynca.pth
Cars-196  Multi-Similarity  cars_ssl-vit-ms.pth
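Zero-shot retrieval models like those above are commonly compared with Recall@K on the held-out test classes. Under this metric, a query counts as a hit if at least one of its K nearest neighbours (excluding itself) shares its class. The NumPy sketch below illustrates the metric using cosine similarity on L2-normalised embeddings. It is not the evaluation code in test_diml.py or RevisitDML, whose exact settings (distance function, K values) may differ, and the function name recall_at_k is hypothetical.

```python
import numpy as np


def recall_at_k(embeddings, labels, ks=(1, 2, 4, 8)):
    """Leave-one-out Recall@K for a retrieval test set.

    Every sample is used as a query against all other samples. Assumes
    max(ks) is smaller than the number of samples.
    """
    x = np.asarray(embeddings, dtype=np.float64)
    y = np.asarray(labels)
    x = x / np.linalg.norm(x, axis=1, keepdims=True)  # L2-normalise each embedding
    sim = x @ x.T                                      # cosine similarity matrix
    np.fill_diagonal(sim, -np.inf)                     # a query must not retrieve itself
    max_k = max(ks)
    # Indices of the max_k most similar samples for each query, best first.
    nn = np.argsort(-sim, axis=1)[:, :max_k]
    hits = y[nn] == y[:, None]
    return {k: float(hits[:, :k].any(axis=1).mean()) for k in ks}
```

Computing the full similarity matrix keeps the sketch short but needs memory quadratic in the test-set size. For large test sets such as SOP, a chunked or approximate nearest-neighbour search would be the usual alternative.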

Acknowledgement

The code is based on:

Owner
Prarthana Bhattacharyya
Ph.D. Candidate @WISELab-UWaterloo