SimReg: A Simple Regression Based Framework for Self-supervised Knowledge Distillation

Source code for the paper "SimReg: Regression as a Simple Yet Effective Tool for Self-supervised Knowledge Distillation".
Paper accepted at British Machine Vision Conference (BMVC), 2021

Overview

We present a simple framework to improve the performance of regression-based knowledge distillation from self-supervised teacher networks. The teacher is trained using a standard self-supervised learning (SSL) technique. The student network is then trained to directly regress the teacher features (using an MSE loss on normalized features). Importantly, during the distillation (training) stage the student architecture contains an additional multi-layer perceptron (MLP) head atop the CNN backbone. The deeper architecture gives the student higher capacity to predict the teacher representations. This additional MLP head can be removed during inference without hurting downstream performance. This is especially surprising since only the output of the MLP is trained to mimic the teacher, and the backbone CNN features have a high MSE loss with respect to the teacher features. This observation allows us to obtain better student models by using deeper models during distillation without altering the inference architecture. The train and test stage architectures are shown in the figure below.
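
The core idea can be sketched in a few lines of PyTorch. This is a minimal illustration, not the repository's exact implementation: the backbone choice, MLP depth, hidden width and output dimension below are assumptions made for the example.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

class StudentWithHead(nn.Module):
    """CNN backbone plus an MLP prediction head used only during distillation."""
    def __init__(self, feat_dim=512, hidden_dim=2048, teacher_dim=2048):
        super().__init__()
        backbone = models.resnet18(pretrained=False)
        backbone.fc = nn.Identity()              # keep the pooled backbone features
        self.backbone = backbone
        self.head = nn.Sequential(               # extra depth/capacity for regressing the teacher
            nn.Linear(feat_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, teacher_dim),
        )

    def forward(self, x, with_head=True):
        feat = self.backbone(x)                  # inference uses only these backbone features
        return self.head(feat) if with_head else feat

def regression_loss(student_out, teacher_feat):
    """MSE between L2-normalized student and teacher features."""
    return F.mse_loss(F.normalize(student_out, dim=1),
                      F.normalize(teacher_feat, dim=1))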

Requirements

All our experiments use the PyTorch library. We recommend installing the following package versions:

  • python=3.7.6
  • pytorch=1.4
  • torchvision=0.5.0
  • faiss-gpu=1.6.1 (required only for k-NN evaluation)

Instructions for PyTorch installation can be found here. The GPU version of the FAISS package is necessary for k-NN evaluation of trained models. It can be installed using the following command:

pip install faiss-gpu

Dataset

We use the ImageNet-1k dataset in our experiments. Download and prepare the dataset using the PyTorch ImageNet training example code. The dataset path needs to be set in the bash scripts used for training and evaluation.
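
As a quick sanity check that the dataset follows the layout torchvision expects (class sub-folders under train/ and val/; the root path below is a placeholder), something like the following should report 1000 classes:

from torchvision import datasets, transforms

root = "/path/to/imagenet"   # placeholder: set this to your ImageNet root
tf = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()])
train_set = datasets.ImageFolder(root + "/train", transform=tf)
val_set = datasets.ImageFolder(root + "/val", transform=tf)
print(len(train_set.classes), len(train_set), len(val_set))   # expect 1000 classes, ~1.28M train / 50k val images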

Training

Distillation can be performed by running the following command:

bash run.sh

Training with a ResNet-50 teacher and a ResNet-18 student requires nearly 2.5 days on four 2080Ti GPUs (~26 min/epoch). The default hyperparameter values are set to the ones used in the paper. Modify the teacher and student architectures as necessary. Set the appropriate paths for the ImageNet dataset root and the experiment root. The current code will generate a directory named exp_dir containing checkpoints and logs sub-directories.
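
For orientation, a single distillation step roughly amounts to the following. This is a simplified sketch rather than the repository's training loop: the teacher below is a randomly initialized torchvision ResNet-50 standing in for the frozen SSL checkpoint, and the learning rate and head dimensions are illustrative.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

# Frozen teacher (load the real self-supervised checkpoint here; random weights are a placeholder).
teacher = models.resnet50(pretrained=False)
teacher.fc = nn.Identity()
teacher.eval()
for p in teacher.parameters():
    p.requires_grad = False

# Student backbone + MLP head (illustrative dimensions).
student_backbone = models.resnet18(pretrained=False)
student_backbone.fc = nn.Identity()
student_head = nn.Sequential(nn.Linear(512, 2048), nn.BatchNorm1d(2048),
                             nn.ReLU(inplace=True), nn.Linear(2048, 2048))
student = nn.Sequential(student_backbone, student_head)

optimizer = torch.optim.SGD(student.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4)  # illustrative values

images = torch.randn(8, 3, 224, 224)              # stand-in for a batch from the ImageNet dataloader
with torch.no_grad():
    target = F.normalize(teacher(images), dim=1)  # teacher features (these can also be cached to disk)
pred = F.normalize(student(images), dim=1)
loss = F.mse_loss(pred, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()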

Evaluation

Set the experiment name and checkpoint epoch in the evaluation bash scripts. The trained checkpoints are assumed to be stored as exp_dir/checkpoints/ckpt_epoch_<num>.pth. Edit the weights argument to load model parameters from a custom checkpoint.
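
If you need to inspect or reuse a checkpoint outside the provided scripts, a rough sketch is given below. The key names ("state_dict", the "module." prefix) are assumptions about how the checkpoints are saved and may need adjusting.

import torch

# Substitute the actual epoch number for <num>.
ckpt = torch.load("exp_dir/checkpoints/ckpt_epoch_<num>.pth", map_location="cpu")
print(ckpt.keys())

# Assumed layout: model weights live under a 'state_dict' key, possibly with a 'module.' prefix
# left over from DataParallel/DistributedDataParallel training.
state_dict = ckpt.get("state_dict", ckpt)
state_dict = {k.replace("module.", "", 1): v for k, v in state_dict.items()}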

k-NN Evaluation

k-NN evaluation requires the GPU version of the FAISS package (see Requirements). We evaluate the performance of the CNN backbone features. Run k-NN evaluation using:

bash knn_eval.sh

The image features and the results of k-NN evaluation (k=1 and k=20) are stored under exp_dir/features/.
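
For reference, the evaluation conceptually boils down to the sketch below: a plain k-NN classifier with majority voting on L2-normalized features. This is a simplified illustration, not the exact evaluation code, which may use a weighted vote and different pre-processing.

import faiss
import numpy as np

def knn_accuracy(train_feats, train_labels, val_feats, val_labels, k=20):
    """Plain k-NN classification on L2-normalized features using a GPU FAISS index."""
    train_feats = train_feats / np.linalg.norm(train_feats, axis=1, keepdims=True)
    val_feats = val_feats / np.linalg.norm(val_feats, axis=1, keepdims=True)

    index = faiss.IndexFlatIP(train_feats.shape[1])             # inner product == cosine on normalized features
    index = faiss.index_cpu_to_all_gpus(index)                  # requires faiss-gpu
    index.add(train_feats.astype(np.float32))

    _, nn_idx = index.search(val_feats.astype(np.float32), k)   # (num_val, k) neighbor indices
    votes = train_labels[nn_idx]                                 # labels of the k nearest training samples
    preds = np.array([np.bincount(v).argmax() for v in votes])   # majority vote
    return (preds == val_labels).mean()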

Linear Evaluation

Here, we train a single linear layer atop the CNN backbone using an SGD optimizer for 40 epochs. The evaluation can be performed using the following code:

bash lin_eval.sh

The evaluation results are stored under exp_dir/linear/. Set the use_cache argument in the bash script to use cached features for evaluation. With this argument, features are computed once for caching and the linear layer is then trained for 40 epochs on the cached features. While this usually results in slightly reduced performance, it can be used for faster evaluation of intermediate checkpoints.
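
In essence, the cached-feature variant does the following (a minimal sketch on pre-extracted features; the optimizer settings and feature dimension below are illustrative, and the actual script handles caching, schedules and logging itself):

import torch
import torch.nn as nn

def train_linear_probe(train_feats, train_labels, val_feats, val_labels,
                       feat_dim=512, num_classes=1000, epochs=40,
                       lr=0.01, batch_size=256, device="cuda"):
    """Fit a single linear layer on frozen (cached) backbone features with SGD."""
    clf = nn.Linear(feat_dim, num_classes).to(device)
    opt = torch.optim.SGD(clf.parameters(), lr=lr, momentum=0.9)
    loss_fn = nn.CrossEntropyLoss()
    loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(train_feats, train_labels),
        batch_size=batch_size, shuffle=True)

    for _ in range(epochs):
        for feats, labels in loader:
            feats, labels = feats.to(device), labels.to(device)
            opt.zero_grad()
            loss_fn(clf(feats), labels).backward()
            opt.step()

    with torch.no_grad():                         # top-1 accuracy on cached validation features
        preds = clf(val_feats.to(device)).argmax(dim=1).cpu()
    return (preds == val_labels).float().mean().item()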

Pretrained Models

To evaluate the pretrained models, create an experiment root directory exp_dir and place the checkpoint in exp_dir/checkpoints/. Set the exp argument in the evaluation bash scripts to perform k-NN and linear evaluation. We provide the pretrained teacher (obtained using the officially shared checkpoints for the corresponding SSL teacher) and our distilled student model weights. We use cached features of the teacher in some of our experiments for faster training.

Teacher | Student | 1-NN | Linear
MoCo-v2 ResNet-50 | MobileNet-v2 | 55.5 | 69.1
MoCo-v2 ResNet-50 | ResNet-18 | 54.8 | 65.1
SimCLR ResNet-50x4 | ResNet-50 (cached) | 60.3 | 74.2
BYOL ResNet-50 | ResNet-18 (cached) | 56.7 | 66.8
SwAV ResNet-50 (cached) | ResNet-18 | 54.0 | 65.8

TODO

  • Add code for transfer learning evaluation
  • Reformat evaluation codes
  • Add code to evaluate models at different stages of CNN backbone and MLP head

Citation

If you make use of the code, please cite the following work:

@inproceedings{navaneet2021simreg,
 author = {Navaneet, K L and Koohpayegani, Soroush Abbasi and Tejankar, Ajinkya and Pirsiavash, Hamed},
 booktitle = {British Machine Vision Conference (BMVC)},
 title = {SimReg: Regression as a Simple Yet Effective Tool for Self-supervised Knowledge Distillation},
 year = {2021}
}

License

This project is under the MIT license.
