SimReg: A Simple Regression Based Framework for Self-supervised Knowledge Distillation

Source code for the paper "SimReg: Regression as a Simple Yet Effective Tool for Self-supervised Knowledge Distillation".
Paper accepted at the British Machine Vision Conference (BMVC), 2021.

Overview

We present a simple framework to improve the performance of regression-based knowledge distillation from self-supervised teacher networks. The teacher is trained using a standard self-supervised learning (SSL) technique. The student network is then trained to directly regress the teacher features, using an MSE loss on normalized features. Importantly, during the distillation (training) stage the student architecture contains an additional multi-layer perceptron (MLP) head atop the CNN backbone. This deeper architecture gives the student more capacity to predict the teacher representations. The MLP head can be removed at inference time without hurting downstream performance. This is especially surprising since only the output of the MLP is trained to mimic the teacher, while the backbone CNN features have a high MSE loss with respect to the teacher features. This observation lets us obtain better student models by using deeper models during distillation without altering the inference architecture. The train and test stage architectures are shown in the figure below.
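A minimal PyTorch sketch of the distillation objective described above follows. The names StudentWithHead, regression_loss and distill_step, the two-layer MLP design, and the exact loss reduction are illustrative assumptions, not the repository's implementation. The loss here is the squared Euclidean distance between L2-normalized student and teacher features, averaged over the batch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class StudentWithHead(nn.Module):
    """Hypothetical student: CNN backbone plus an MLP head used only during distillation."""

    def __init__(self, backbone: nn.Module, feat_dim: int, hidden_dim: int, out_dim: int):
        super().__init__()
        self.backbone = backbone
        # Assumed 2-layer MLP head; it adds capacity for regressing teacher features.
        self.head = nn.Sequential(
            nn.Linear(feat_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x: torch.Tensor, use_head: bool = True) -> torch.Tensor:
        feats = self.backbone(x)
        # At inference the head is skipped and the backbone features are used directly.
        return self.head(feats) if use_head else feats


def regression_loss(student_out: torch.Tensor, teacher_out: torch.Tensor) -> torch.Tensor:
    """Squared L2 distance between normalized features, averaged over the batch."""
    s = F.normalize(student_out, dim=1)
    t = F.normalize(teacher_out.detach(), dim=1)
    return (s - t).pow(2).sum(dim=1).mean()


def distill_step(student: nn.Module, teacher: nn.Module,
                 images: torch.Tensor, optimizer: torch.optim.Optimizer) -> float:
    """One distillation step: the frozen teacher provides the regression targets."""
    teacher.eval()
    with torch.no_grad():
        target = teacher(images)
    loss = regression_loss(student(images), target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

After distillation, only the backbone path (student(x, use_head=False)) would be used for k-NN and linear evaluation, mirroring the train/test split of the architecture.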

Requirements

All our experiments use the PyTorch library. We recommend installing the following package versions:

  • python=3.7.6
  • pytorch=1.4
  • torchvision=0.5.0
  • faiss-gpu=1.6.1 (required only for k-NN evaluation)

Instructions for installing PyTorch are available on the PyTorch website. The GPU version of the FAISS package is required for k-NN evaluation of trained models and can be installed using the following command:

pip install faiss-gpu
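Before running k-NN evaluation, it can help to confirm that FAISS is importable and sees a GPU. The helper below is a hypothetical convenience check, not part of the repository; it relies on faiss.get_num_gpus() from FAISS GPU builds, with a hasattr guard in case a build does not expose it.

```python
def faiss_gpu_available() -> bool:
    """Return True if FAISS can be imported and reports at least one GPU."""
    try:
        import faiss  # installed via `pip install faiss-gpu`
    except ImportError:
        return False
    # get_num_gpus() reports how many GPUs FAISS can see.
    return hasattr(faiss, "get_num_gpus") and faiss.get_num_gpus() > 0


if __name__ == "__main__":
    print("FAISS GPU available:", faiss_gpu_available())
```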

Dataset

We use the ImageNet-1k dataset in our experiments. Download and prepare the dataset following the PyTorch ImageNet training example. The dataset path must be set in the bash scripts used for training and evaluation.
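The PyTorch ImageNet example reads the data with torchvision's ImageFolder, which expects train/ and val/ directories containing one sub-folder per class. The hypothetical helper below checks that layout before a long training run, assuming the SimReg scripts expect the same structure; the function name and the default of 1000 classes (ImageNet-1k) are illustrative.

```python
from pathlib import Path
from typing import Dict


def check_imagenet_layout(root: str, expected_classes: int = 1000) -> Dict[str, int]:
    """Check that root/train and root/val each contain one sub-folder per class."""
    counts = {}
    for split in ("train", "val"):
        split_dir = Path(root) / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing split directory: {split_dir}")
        # ImageFolder treats each sub-folder as one class.
        counts[split] = sum(1 for p in split_dir.iterdir() if p.is_dir())
        if counts[split] != expected_classes:
            raise ValueError(
                f"{split_dir} has {counts[split]} class folders, expected {expected_classes}"
            )
    return counts
```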

Training

Distillation can be performed by running the following command:

bash run.sh

Training with a ResNet-50 teacher and a ResNet-18 student takes nearly 2.5 days on four 2080 Ti GPUs (~26 min/epoch). The default hyperparameter values are set to those used in the paper. Modify the teacher and student architectures as necessary, and set the appropriate paths for the ImageNet dataset root and the experiment root. The code generates a directory named exp_dir containing checkpoints and logs sub-directories.
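The reported per-epoch time can be turned into a rough time budget before launching a run. The sketch below uses the ~26 min/epoch figure (ResNet-50 teacher, ResNet-18 student, four 2080 Ti GPUs) as its default; actual timings depend on hardware, data loading and the chosen architectures, and the number of epochs to plug in should be read from run.sh. Both function names are hypothetical.

```python
def estimate_training_days(num_epochs: int, minutes_per_epoch: float = 26.0) -> float:
    """Convert an epoch count and per-epoch time (minutes) into wall-clock days."""
    return num_epochs * minutes_per_epoch / (60 * 24)


def epochs_within_budget(budget_days: float, minutes_per_epoch: float = 26.0) -> int:
    """Largest whole number of epochs that fits into a given time budget (days)."""
    return int(budget_days * 24 * 60 // minutes_per_epoch)
```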

Evaluation

Set the experiment name and checkpoint epoch in the evaluation bash scripts. The trained checkpoints are assumed to be stored as exp_dir/checkpoints/ckpt_epoch_<num>.pth. Edit the weights argument to load model parameters from a custom checkpoint.
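When choosing the checkpoint epoch for the evaluation scripts, a small helper can locate checkpoints that follow the exp_dir/checkpoints/ckpt_epoch_<num>.pth convention. The function names below are hypothetical, and the sketch assumes <num> is an unpadded integer epoch.

```python
import re
from pathlib import Path
from typing import Optional

# Naming convention from this README: exp_dir/checkpoints/ckpt_epoch_<num>.pth
_CKPT_RE = re.compile(r"^ckpt_epoch_(\d+)\.pth$")


def checkpoint_path(exp_dir: str, epoch: int) -> Path:
    """Path of the checkpoint saved at a given epoch."""
    return Path(exp_dir) / "checkpoints" / f"ckpt_epoch_{epoch}.pth"


def latest_checkpoint(exp_dir: str) -> Optional[Path]:
    """Return the checkpoint with the highest epoch number, or None if there is none."""
    ckpt_dir = Path(exp_dir) / "checkpoints"
    if not ckpt_dir.is_dir():
        return None
    found = []
    for p in ckpt_dir.iterdir():
        m = _CKPT_RE.match(p.name)
        if m:
            found.append((int(m.group(1)), p))
    # Compare epochs numerically so that epoch 100 ranks above epoch 99.
    return max(found, key=lambda item: item[0])[1] if found else None
```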

k-NN Evaluation

k-NN evaluation requires the FAISS-GPU package. We evaluate the performance of the CNN backbone features. Run k-NN evaluation using:

bash knn_eval.sh

The image features and the k-NN evaluation results (k=1 and k=20) are stored in the exp_dir/features/ directory.
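For intuition, k-NN evaluation on frozen backbone features amounts to a nearest-neighbour search followed by a label vote. The NumPy sketch below performs a brute-force cosine-similarity search with plain majority voting. It is an illustrative stand-in for the FAISS-based script, which searches far more efficiently on GPU, and the repository's exact voting scheme (for example, similarity weighting at k=20) may differ.

```python
import numpy as np


def knn_accuracy(train_feats: np.ndarray, train_labels: np.ndarray,
                 test_feats: np.ndarray, test_labels: np.ndarray, k: int = 20) -> float:
    """Top-1 accuracy of a majority-vote k-NN classifier using cosine similarity."""
    # L2-normalize so the inner product equals cosine similarity.
    train = train_feats / np.linalg.norm(train_feats, axis=1, keepdims=True)
    test = test_feats / np.linalg.norm(test_feats, axis=1, keepdims=True)
    sims = test @ train.T                       # (num_test, num_train)
    topk = np.argsort(-sims, axis=1)[:, :k]     # indices of the k most similar samples
    neighbor_labels = train_labels[topk]        # (num_test, k)
    num_classes = int(train_labels.max()) + 1
    # Count votes per class; argmax breaks ties in favour of the smaller class index.
    votes = np.stack([(neighbor_labels == c).sum(axis=1) for c in range(num_classes)], axis=1)
    preds = votes.argmax(axis=1)
    return float((preds == test_labels).mean())
```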

Linear Evaluation

Here, we train a single linear layer atop the CNN backbone using an SGD optimizer for 40 epochs. Run linear evaluation using:

bash lin_eval.sh

The evaluation results are stored in the exp_dir/linear/ directory. Set the use_cache argument in the bash script to evaluate using cached features. With this argument, features are computed once and cached, and the linear layer is then trained for 40 epochs on the cached features. While this usually results in slightly lower performance, it allows faster evaluation of intermediate checkpoints.
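The cached-feature mode can be pictured as follows: backbone features are computed once, and only a linear classifier is then optimized on them. The sketch below is a hypothetical illustration; apart from the use of SGD and the 40 epochs, the learning rate, momentum and batch size are assumed values, not those in lin_eval.sh.

```python
import torch
import torch.nn as nn


def train_linear_probe(feats: torch.Tensor, labels: torch.Tensor, num_classes: int,
                       epochs: int = 40, lr: float = 0.1, batch_size: int = 256) -> nn.Linear:
    """Fit a single linear layer on frozen (cached) backbone features with SGD."""
    clf = nn.Linear(feats.shape[1], num_classes)
    opt = torch.optim.SGD(clf.parameters(), lr=lr, momentum=0.9)
    loss_fn = nn.CrossEntropyLoss()
    n = feats.shape[0]
    for _ in range(epochs):
        perm = torch.randperm(n)  # reshuffle the cached features every epoch
        for start in range(0, n, batch_size):
            idx = perm[start:start + batch_size]
            loss = loss_fn(clf(feats[idx]), labels[idx])
            opt.zero_grad()
            loss.backward()
            opt.step()
    return clf


@torch.no_grad()
def probe_accuracy(clf: nn.Linear, feats: torch.Tensor, labels: torch.Tensor) -> float:
    """Top-1 accuracy of the linear probe on a set of features."""
    return (clf(feats).argmax(dim=1) == labels).float().mean().item()
```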

Pretrained Models

To evaluate the pretrained models, create an experiment root directory exp_dir and place the checkpoint in exp_dir/checkpoints/. Set the exp argument in the evaluation bash scripts to perform k-NN and linear evaluation. We provide the pretrained teacher (obtained using the officially shared checkpoints for the corresponding SSL teacher) and our distilled student model weights. We use cached features of the teacher in some of our experiments for faster training.

Teacher                  | Student            | 1-NN (%) | Linear (%)
MoCo-v2 ResNet-50        | MobileNet-v2       | 55.5     | 69.1
MoCo-v2 ResNet-50        | ResNet-18          | 54.8     | 65.1
SimCLR ResNet-50x4       | ResNet-50 (cached) | 60.3     | 74.2
BYOL ResNet-50           | ResNet-18 (cached) | 56.7     | 66.8
SwAV ResNet-50 (cached)  | ResNet-18          | 54.0     | 65.8
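Setting up a pretrained model for evaluation only requires placing the file under exp_dir/checkpoints/, as described above. The hypothetical helper below does this and, as an assumption, renames the file to the ckpt_epoch_<num>.pth convention so that the checkpoint epoch set in the evaluation scripts can point to it; alternatively, the weights argument can be edited to load the file under its original name.

```python
import shutil
from pathlib import Path


def install_pretrained(downloaded: str, exp_dir: str, epoch: int) -> Path:
    """Copy a downloaded checkpoint to exp_dir/checkpoints/ckpt_epoch_<epoch>.pth."""
    ckpt_dir = Path(exp_dir) / "checkpoints"
    ckpt_dir.mkdir(parents=True, exist_ok=True)  # create exp_dir and checkpoints/ if needed
    target = ckpt_dir / f"ckpt_epoch_{epoch}.pth"
    shutil.copy2(downloaded, target)  # copy, leaving the downloaded file in place
    return target
```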

TODO

  • Add code for transfer learning evaluation
  • Reformat evaluation code
  • Add code to evaluate models at different stages of CNN backbone and MLP head

Citation

If you make use of the code, please cite the following work:

@inproceedings{navaneet2021simreg,
 author = {Navaneet, K L and Koohpayegani, Soroush Abbasi and Tejankar, Ajinkya and Pirsiavash, Hamed},
 booktitle = {British Machine Vision Conference (BMVC)},
 title = {SimReg: Regression as a Simple Yet Effective Tool for Self-supervised Knowledge Distillation},
 year = {2021}
}

License

This project is under the MIT license.
