[IJCAI-2021] A benchmark of data-free knowledge distillation from paper "Contrastive Model Inversion for Data-Free Knowledge Distillation"

DataFree

Authors: Gongfan Fang, Jie Song, Xinchao Wang, Chengchao Shen, Xingen Wang, Mingli Song

[Figure panels: CMI (this work), DeepInv, ZSKT, DFQ]

Results

1. CIFAR-10

Method       resnet-34 => resnet-18   vgg-11 => resnet-18   wrn-40-2 => wrn-16-1   wrn-40-2 => wrn-40-1   wrn-40-2 => wrn-16-2
T. Scratch   95.70                    92.25                 94.87                  94.87                  94.87
S. Scratch   95.20                    95.20                 91.12                  93.94                  93.95
DAFL         92.22                    81.10                 65.71                  81.33                  81.55
ZSKT         93.32                    89.46                 83.74                  86.07                  89.66
DeepInv      93.26                    90.36                 83.04                  86.85                  89.72
DFQ          94.61                    90.84                 86.14                  91.69                  92.01
CMI          94.84                    91.13                 90.01                  92.78                  92.52

2. CIFAR-100

Method       resnet-34 => resnet-18   vgg-11 => resnet-18   wrn-40-2 => wrn-16-1   wrn-40-2 => wrn-40-1   wrn-40-2 => wrn-16-2
T. Scratch   78.05                    71.32                 75.83                  75.83                  75.83
S. Scratch   77.10                    77.01                 65.31                  72.19                  73.56
DAFL         74.47                    57.29                 22.50                  34.66                  40.00
ZSKT         67.74                    34.72                 30.15                  29.73                  28.44
DeepInv      61.32                    54.13                 53.77                  61.33                  61.34
DFQ          77.01                    68.32                 54.77                  62.92                  59.01
CMI          77.04                    70.56                 57.91                  68.88                  68.75

Quick Start

1. Visualize the inverted samples

Results will be saved as checkpoints/datafree-cmi/synthetic-cmi_for_vis.png

bash scripts/cmi/cmi_cifar10_for_vis.sh

2. Reproduce our results

Note: This repo was refactored from our experimental code and is still under development. We are still searching for appropriate hyperparameters for every method (°ー°〃). So far, we only provide the hyperparameters to reproduce the CIFAR-10 results for wrn-40-2 => wrn-16-1. You may need to tune the hyperparameters for other models and datasets. More resources will be uploaded in future updates.

To reproduce our results, please download the pre-trained teacher models from Dropbox-Models (266 MB) and extract them to checkpoints/pretrained. A pre-inverted dataset with ~50k samples is also available for the wrn-40-2 teacher on CIFAR-10. You can download it from Dropbox-Data (133 MB) and extract it to run/cmi-preinverted-wrn402/.
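Before launching the scripts, it may help to confirm that both archives ended up where the instructions above expect them. The following is only a minimal sketch, assuming the commands are run from the repository root; the helper `missing_paths` and the constant `EXPECTED_DIRS` are hypothetical and not part of this repo.

```python
from pathlib import Path

# Directories named in the instructions above, relative to the repo root.
EXPECTED_DIRS = (
    "checkpoints/pretrained",      # pre-trained teacher models
    "run/cmi-preinverted-wrn402",  # pre-inverted data for the wrn-40-2 teacher
)


def missing_paths(repo_root="."):
    """Return the expected directories that do not exist under repo_root."""
    root = Path(repo_root)
    return [d for d in EXPECTED_DIRS if not (root / d).is_dir()]


if __name__ == "__main__":
    missing = missing_paths()
    if missing:
        print("Missing directories:", ", ".join(missing))
    else:
        print("All expected directories are present.")
```

The pre-inverted data directory is only needed for the non-adversarial and adversarial CMI scripts; the scratch variant runs without it.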

  • Non-adversarial CMI: you can train a student model on the inverted data directly. It should reach an accuracy of ~87.38% on CIFAR-10, as reported in Figure 3.

    bash scripts/cmi/nonadv_cmi_cifar10_wrn402_wrn161.sh
    
  • Adversarial CMI: alternatively, you can apply adversarial distillation on top of the pre-inverted data, where ~10k (256x40) new samples will be generated to improve the student. It should reach an accuracy of ~90.01% on CIFAR-10, as reported in Table 1.

    bash scripts/cmi/adv_cmi_cifar10_wrn402_wrn161.sh
    
  • Scratch CMI: it is fine to run the CMI algorithm without any pre-inverted data, but the student may overfit to early samples due to the limited amount of data. It should reach an accuracy of ~88.82% on CIFAR-10, slightly worse than our reported result (90.01%).

    bash scripts/cmi/scratch_cmi_cifar10_wrn402_wrn161.sh
    

3. Scratch training

python train_scratch.py --model wrn40_2 --dataset cifar10 --batch-size 256 --lr 0.1 --epoch 200 --gpu 0

4. Vanilla KD

# KD with original training data (beta>0 to use hard targets)
python vanilla_kd.py --teacher wrn40_2 --student wrn16_1 --dataset cifar10 --transfer_set cifar10 --beta 0.1 --batch-size 128 --lr 0.1 --epoch 200 --gpu 0 

# KD with unlabeled data
python vanilla_kd.py --teacher wrn40_2 --student wrn16_1 --dataset cifar10 --transfer_set cifar100 --beta 0 --batch-size 128 --lr 0.1 --epoch 200 --gpu 0 

# KD with unlabeled data from a specified folder
python vanilla_kd.py --teacher wrn40_2 --student wrn16_1 --dataset cifar10 --transfer_set run/cmi --beta 0 --batch-size 128 --lr 0.1 --epoch 200 --gpu 0 
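
The --beta flag above controls how much the hard targets contribute. As a rough illustration only, here is a framework-free sketch of one common KD objective: a KL term between temperature-softened teacher and student predictions, plus a beta-weighted cross-entropy on the ground-truth label. The exact loss in vanilla_kd.py may differ; the function names, the temperature T, and the T^2 scaling are assumptions.

```python
import math


def log_softmax(logits, T=1.0):
    """Numerically stable log-softmax of logits divided by temperature T."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    lse = m + math.log(sum(math.exp(z - m) for z in scaled))
    return [z - lse for z in scaled]


def kd_loss(student_logits, teacher_logits, label=None, beta=0.0, T=1.0):
    """Soft-target KL term plus an optional beta-weighted hard-target term.

    Assumed form: T^2 * KL(teacher || student) + beta * CE(student, label).
    With beta == 0 no label is needed, which matches the unlabeled-data case.
    """
    log_t = log_softmax(teacher_logits, T)
    log_s = log_softmax(student_logits, T)
    kl = sum(math.exp(lt) * (lt - ls) for lt, ls in zip(log_t, log_s))
    loss = (T ** 2) * kl
    if beta > 0:
        if label is None:
            raise ValueError("beta > 0 requires a ground-truth label")
        loss += beta * -log_softmax(student_logits)[label]
    return loss
```

With beta set to 0, as in the two unlabeled-data commands, only the teacher's soft predictions drive the student, so the transfer set does not need labels.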

5. Data-free KD

bash scripts/xxx/xxx.sh # e.g. scripts/zskt/zskt_cifar10_wrn402_wrn161.sh

Hyper-parameters used by different methods:

Method adv bn oh balance act cr GAN Example
DAFL - - - scripts/dafl_cifar10.sh
ZSKT - - - - - scripts/zskt_cifar10.sh
DeepInv - - - - scripts/deepinv_cifar10.sh
DFQ - - scripts/dfq_cifar10.sh
CMI - - scripts/cmi_cifar10_scratch.sh

6. Use your models/datasets

You can register your models and datasets in registry.py by modifying NORMALIZE_DICT, MODEL_DICT and get_dataset. Then you can run the above commands to train your own models. As DAFL requires intermediate features from the penultimate layer, your model should accept a return_features=True parameter and return a (logits, features) tuple for DAFL.
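
The return_features contract described above could look roughly like the sketch below. It is framework-free so that it runs on its own; a real model would be a torch.nn.Module. The class `TinyClassifier`, its two toy layers, and the default value return_features=False are hypothetical and only illustrate the calling convention.

```python
class TinyClassifier:
    """Toy stand-in for a registered model (a real one would subclass torch.nn.Module)."""

    def __init__(self, weights, classifier):
        self.weights = weights        # feature extractor: one row of weights per feature
        self.classifier = classifier  # final linear layer: one row of weights per class

    @staticmethod
    def _linear(rows, x):
        """Multiply a weight matrix (list of rows) by the vector x."""
        return [sum(w * v for w, v in zip(row, x)) for row in rows]

    def forward(self, x, return_features=False):
        # Penultimate-layer features (ReLU of a linear map in this toy example).
        features = [max(0.0, h) for h in self._linear(self.weights, x)]
        logits = self._linear(self.classifier, features)
        if return_features:
            return logits, features  # the (logits, features) tuple DAFL relies on
        return logits

    __call__ = forward
```

Calling `model(x)` returns logits only, so the other methods are unaffected, while `model(x, return_features=True)` also exposes the penultimate-layer features.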

7. Implement your algorithms

Your algorithms should inherit from datafree.synthesis.BaseSynthesizer and implement two interfaces: 1) BaseSynthesizer.synthesize takes several steps to craft new samples and returns an image dict for visualization; 2) BaseSynthesizer.sample fetches a batch of training data for KD.
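
The two-method interface can be sketched as follows. `BaseSynthesizer` below is a local stand-in so the example runs on its own; the real class lives in datafree.synthesis and its constructor and return types may differ. `RandomNoiseSynthesizer`, its parameters, and the "synthetic" dict key are hypothetical.

```python
import random
from abc import ABC, abstractmethod


class BaseSynthesizer(ABC):
    """Local stand-in for datafree.synthesis.BaseSynthesizer (real signature may differ)."""

    @abstractmethod
    def synthesize(self):
        """Run several steps to craft new samples; return an image dict for visualization."""

    @abstractmethod
    def sample(self):
        """Fetch one batch of training data for KD."""


class RandomNoiseSynthesizer(BaseSynthesizer):
    """Hypothetical synthesizer that 'crafts' Gaussian noise and stores it in a pool."""

    def __init__(self, sample_dim, synthesis_batch_size, sample_batch_size, seed=0):
        self.sample_dim = sample_dim
        self.synthesis_batch_size = synthesis_batch_size
        self.sample_batch_size = sample_batch_size
        self.rng = random.Random(seed)
        self.pool = []  # every sample synthesized so far

    def synthesize(self):
        new = [[self.rng.gauss(0.0, 1.0) for _ in range(self.sample_dim)]
               for _ in range(self.synthesis_batch_size)]
        self.pool.extend(new)
        return {"synthetic": new}  # key name is an assumption

    def sample(self):
        # Draw without replacement from the pool; empty list if nothing was synthesized yet.
        k = min(self.sample_batch_size, len(self.pool))
        return self.rng.sample(self.pool, k)
```

A training loop would typically alternate between the two calls: `synthesize()` to grow the pool, then several `sample()` calls to feed batches to the student.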

Citation

If you found this work useful for your research, please cite our paper:

@misc{fang2021contrastive,
      title={Contrastive Model Inversion for Data-Free Knowledge Distillation}, 
      author={Gongfan Fang and Jie Song and Xinchao Wang and Chengchao Shen and Xingen Wang and Mingli Song},
      year={2021},
      eprint={2105.08584},
      archivePrefix={arXiv},
      primaryClass={cs.AI}
}
