Pytorch implementation of ICASSP 2022 paper Attention Probe: Vision Transformer Distillation in the Wild

Overview

Attention Probe: Vision Transformer Distillation in the Wild

License: MIT

Jiahao Wang, Mingdeng Cao, Shuwei Shi, Baoyuan Wu, Yujiu Yang
In ICASSP 2022

This code is the Pytorch implementation of ICASSP 2022 paper Attention Probe: Vision Transformer Distillation in the Wild

Overview

  • We propose the concept of Attention Probe, a special section of the attention map to utilize a large amount of unlabeled data in the wild to complete the vision transformer data-free distillation task. Instead of generating images from the teacher network with a series of priori, images most relevant to the given pre-trained network and tasks will be identified from a large unlabeled dataset (e.g., Flickr) to conduct the knowledge distillation task.
  • We propose a simple yet efficient distillation algorithm, called probe distillation, to distill the student model using intermediate features of the teacher model, which is based on the Attention Probe.

Prerequisite

We use Pytorch 1.7.1, and CUDA 11.0. You can install them with

pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html

It should also be applicable to other Pytorch and CUDA versions.

Usage

Data Preparation

First, you need to modify the storage format of the cifar-10/100 and tinyimagenet dataset to the style of ImageNet, etc. CIFAR 10 run:

python process_cifar10.py

CIFAR 100 run:

python process_cifar100.py

Tiny-ImageNet run:

python process_tinyimagenet.py
python process_move_file.py

The dataset dir should have the following structure:

dir/
  train/
    ...
  val/
    n01440764/
      ILSVRC2012_val_00000293.JPEG
      ...
    ...

Train a normal teacher network

For this step you need to train normal teacher transformer models for selecting valuable data from the wild. We train the teacher model based on the timm PyTorch library:

timm

Our pretrained teacher models (CIFAR-10, CIFAR-100, ImageNet, Tiny-ImageNet, MNIST) can be downloaded from here:

Pretrained teacher models

Select valuable data from the wild

Then, you can use the Attention Probe method to select valuable data in the wild dataset.

To select valuable data CIFAR-10 run:

bash training.sh
(CUDA_VISIBLE_DEVICES=0 python DFND_DeiT-train.py --dataset cifar10 --data_cifar $root_cifar10 --data_imagenet $root_wild --num_select 650000 --teacher_dir $teacher_cifar10 --selected_file $selected_cifar10 --output_dir $output_student_cifar10 --nb_classes 10 --lr_S 7.5e-4 --attnprobe_sel --attnprobe_dist )

CIFAR-100 run:

bash training.sh
(CIFAR 100 run: CUDA_VISIBLE_DEVICES=0 python DFND_DeiT-train.py --dataset cifar10 --data_cifar $root_cifar10 --data_imagenet $root_wild --num_select 650000 --teacher_dir $teacher_cifar10 --selected_file $selected_cifar10 --output_dir $output_student_cifar10 --nb_classes 10 --lr_S 7.5e-4 --attnprobe_sel --attnprobe_dist )

TinyImageNet run:

bash training_tinyimagenet.sh

ImageNet run:

bash training_imagenet.sh

After you will get "class_weights.pth, pred_out.pth, value_blk3.pth, value_blk7.pth, value_out.pth" in '/selected/cifar10/' or '/selected/cifar100/' directory, you have already obtained the selected data.

Probe Knowledge Distillation for Student networks

Then you can distill the student model using intermediate features of the teacher model based on the selected data.

bash training.sh
(CIFAR 10 run: CUDA_VISIBLE_DEVICES=0 python DFND_DeiT-train.py --dataset cifar100 --data_cifar $root_cifar100 --data_imagenet $root_wild --num_select 650000 --teacher_dir $teacher_cifar100 --selected_file $selected_cifar100 --output_dir $output_student_cifar100 --nb_classes 100 --lr_S 8.5e-4 --attnprobe_sel --attnprobe_dist)

(CIFAR 100 run: CUDA_VISIBLE_DEVICES=0,1,2,3 python DFND_DeiT-train.py --dataset cifar100 --data_cifar $root_cifar100 --data_imagenet $root_wild --num_select 650000 --teacher_dir $teacher_cifar100 --selected_file $selected_cifar100 --output_dir $output_student_cifar100 --nb_classes 100 --lr_S 8.5e-4 --attnprobe_sel --attnprobe_dist)

TinyImageNet run:

bash training_tinyimagenet.sh

ImageNet run:

bash training_imagenet.sh

you will get the student transformer model in '/output/cifar10/student/' or '/output/cifar100/student/' directory.

Our distilled student models (CIFAR-10, CIFAR-100, ImageNet, Tiny-ImageNet, MNIST) can be downloaded from here: Distilled student models

Results

Citation

@inproceedings{
wang2022attention,
title={Attention Probe: Vision Transformer Distillation in the Wild},
author={Jiahao Wang, Mingdeng Cao, Shuwei Shi, Baoyuan Wu, Yujiu Yang},
booktitle={International Conference on Acoustics, Speech and Signal Processing},
year={2022},
url={https://2022.ieeeicassp.org/}
}

Acknowledgement

Owner
IIGROUP
The Intelligent Interaction Group at Tsinghua University
IIGROUP
Source Code For Template-Based Named Entity Recognition Using BART

Template-Based NER Source Code For Template-Based Named Entity Recognition Using BART Training Training train.py Inference inference.py Corpus ATIS (h

174 Dec 19, 2022
PyTorch Implementation of Small Lesion Segmentation in Brain MRIs with Subpixel Embedding (ORAL, MICCAIW 2021)

Small Lesion Segmentation in Brain MRIs with Subpixel Embedding PyTorch implementation of Small Lesion Segmentation in Brain MRIs with Subpixel Embedd

22 Oct 21, 2022
SymPy-powered, Wolfram|Alpha-like answer engine totally in your browser, without backend computation

SymPy Beta SymPy Beta is a fork of SymPy Gamma. The purpose of this project is to run a SymPy-powered, Wolfram|Alpha-like answer engine totally in you

Liumeo 25 Dec 21, 2022
A pyparsing-based library for parsing SOQL statements

CONTRIBUTORS WANTED!! Installation pip install python-soql-parser or, with poetry poetry add python-soql-parser Usage from python_soql_parser import p

Kicksaw 0 Jun 07, 2022
dataset for ECCV 2020 "Motion Capture from Internet Videos"

Motion Capture from Internet Videos Motion Capture from Internet Videos Junting Dong*, Qing Shuai*, Yuanqing Zhang, Xian Liu, Xiaowei Zhou, Hujun Bao

ZJU3DV 98 Dec 07, 2022
Project to create an open-source 6 DoF input device

6DInputs A Project to create open-source 3D printed 6 DoF input devices Note the plural ('6DInputs' and 'devices') in the headings. We would like seve

RepRap Ltd 47 Jul 28, 2022
Torchreid: Deep learning person re-identification in PyTorch.

Torchreid Torchreid is a library for deep-learning person re-identification, written in PyTorch. It features: multi-GPU training support both image- a

Kaiyang 3.7k Jan 05, 2023
Code for DeepXML: A Deep Extreme Multi-Label Learning Framework Applied to Short Text Documents

DeepXML Code for DeepXML: A Deep Extreme Multi-Label Learning Framework Applied to Short Text Documents Architectures and algorithms DeepXML supports

Extreme Classification 49 Nov 06, 2022
Tacotron 2 - PyTorch implementation with faster-than-realtime inference

Tacotron 2 (without wavenet) PyTorch implementation of Natural TTS Synthesis By Conditioning Wavenet On Mel Spectrogram Predictions. This implementati

NVIDIA Corporation 4.1k Jan 03, 2023
"Moshpit SGD: Communication-Efficient Decentralized Training on Heterogeneous Unreliable Devices", official implementation

Moshpit SGD: Communication-Efficient Decentralized Training on Heterogeneous Unreliable Devices This repository contains the official PyTorch implemen

Yandex Research 21 Oct 18, 2022
Code for "Long Range Probabilistic Forecasting in Time-Series using High Order Statistics"

Long Range Probabilistic Forecasting in Time-Series using High Order Statistics This is the code produced as part of the paper Long Range Probabilisti

16 Dec 06, 2022
BEAMetrics: Benchmark to Evaluate Automatic Metrics in Natural Language Generation

BEAMetrics: Benchmark to Evaluate Automatic Metrics in Natural Language Generation Installing The Dependencies $ conda create --name beametrics python

7 Jul 04, 2022
L-Verse: Bidirectional Generation Between Image and Text

Far beyond learning long-range interactions of natural language, transformers are becoming the de-facto standard for many vision tasks with their power and scalabilty

Kim, Taehoon 102 Dec 21, 2022
Pull sensitive data from users on windows including discord tokens and chrome data.

⭐ For a 🍪 Pegasus Pull sensitive data from users on windows including discord tokens and chrome data. Features 🟩 Discord tokens 🟩 Geolocation data

Addi 44 Dec 31, 2022
PyTorch implementation of Barlow Twins.

Barlow Twins: Self-Supervised Learning via Redundancy Reduction PyTorch implementation of Barlow Twins. @article{zbontar2021barlow, title={Barlow Tw

Facebook Research 839 Dec 29, 2022
A simple, clean TensorFlow implementation of Generative Adversarial Networks with a focus on modeling illustrations.

IllustrationGAN A simple, clean TensorFlow implementation of Generative Adversarial Networks with a focus on modeling illustrations. Generated Images

268 Nov 27, 2022
MLP-Numpy - A simple modular implementation of Multi Layer Perceptron in pure Numpy.

MLP-Numpy A simple modular implementation of Multi Layer Perceptron in pure Numpy. I used the Iris dataset from scikit-learn library for the experimen

Soroush Omranpour 1 Jan 01, 2022
A deep neural networks for images using CNN algorithm.

Example-CNN-Project This is a simple project showing how to implement deep neural networks using CNN algorithm. The dataset is taken from this link: h

Mohammad Amin Dadgar 3 Sep 16, 2022
Code for Understanding Pooling in Graph Neural Networks

Select, Reduce, Connect This repository contains the code used for the experiments of: "Understanding Pooling in Graph Neural Networks" Setup Install

Daniele Grattarola 37 Dec 13, 2022
Forecasting Nonverbal Social Signals during Dyadic Interactions with Generative Adversarial Neural Networks

ForecastingNonverbalSignals This is the implementation for the paper Forecasting Nonverbal Social Signals during Dyadic Interactions with Generative A

1 Feb 10, 2022