Official PyTorch implementation of PS-KD

Overview


Self-Knowledge Distillation with Progressive Refinement of Targets (PS-KD)

Accepted at ICCV 2021, oral presentation

  • Official PyTorch implementation of Self-Knowledge Distillation with Progressive Refinement of Targets (PS-KD).
    [Slides] [Paper] [Video]
  • Kyungyul Kim, ByeongMoon Ji, Doyoung Yoon and Sangheum Hwang

Abstract

The generalization capability of deep neural networks has been substantially improved by applying a wide spectrum of regularization methods, e.g., restricting function space, injecting randomness during training, augmenting data, etc. In this work, we propose a simple yet effective regularization method named progressive self-knowledge distillation (PS-KD), which progressively distills a model's own knowledge to soften hard targets (i.e., one-hot vectors) during training. Hence, it can be interpreted within a framework of knowledge distillation as a student becomes a teacher itself. Specifically, targets are adjusted adaptively by combining the ground-truth and past predictions from the model itself. Please refer to the paper for more details.
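The abstract describes soft targets built from the ground truth and the model's own past predictions. Below is a minimal plain-Python sketch that assumes the formulation in the paper as we read it. The target at epoch t is (1 - alpha_t) * y + alpha_t * p_{t-1}(x), where y is the one-hot label and p_{t-1}(x) is the model's prediction from the previous epoch. The weight alpha_t = alpha_T * t / T grows linearly and reaches alpha_T (the --alpha_T flag used below) at the last epoch T. All function names are hypothetical and do not come from main.py. The repository works on PyTorch tensors and keeps the past predictions per training sample.

```python
import math
from typing import List


def alpha_at_epoch(alpha_T: float, epoch: int, total_epochs: int) -> float:
    """alpha_t = alpha_T * t / T: the weight on past predictions grows linearly."""
    return alpha_T * epoch / total_epochs


def softened_target(one_hot: List[float], past_probs: List[float], alpha: float) -> List[float]:
    """Blend the hard one-hot label with the model's prediction from the previous epoch."""
    return [(1.0 - alpha) * y + alpha * p for y, p in zip(one_hot, past_probs)]


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the max logit for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def soft_cross_entropy(target: List[float], logits: List[float]) -> float:
    """Cross-entropy between the softened target and the current prediction."""
    probs = softmax(logits)
    # classes with zero target mass contribute nothing to the sum
    return -sum(t * math.log(p) for t, p in zip(target, probs) if t > 0)


if __name__ == "__main__":
    # Assumed setup: 300 epochs in total, alpha_T = 0.8, currently at epoch 150.
    alpha = alpha_at_epoch(0.8, epoch=150, total_epochs=300)
    target = softened_target([0.0, 1.0, 0.0], [0.1, 0.7, 0.2], alpha)
    print(alpha, target)  # roughly 0.4 and [0.04, 0.88, 0.08]
    print(soft_cross_entropy(target, [0.5, 2.0, 0.1]))
```

Early in training alpha_t is close to 0 and the target is almost the hard label. As the model's own predictions become more reliable, they receive more weight.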

Requirements

We have tested the code in the following environment:

  • Python 3.7.7 / PyTorch (>=1.6.0) / torchvision (>=0.7.0)

Datasets

Currently, only the CIFAR-100 and ImageNet datasets are supported.

Note: to verify the effectiveness of PS-KD on object detection and machine translation tasks, we used the following datasets:

  • For object detection: Pascal VOC
  • For machine translation: IWSLT 15 English-German / German-English, Multi30k.
  • (Please refer to the paper for more details)

How to Run

Single-node & Multi-GPU Training

To train a single model on one node with multiple GPUs, run the following command:

$ python3 main.py --lr 0.1 \
                  --lr_decay_schedule 150 225 \
                  --PSKD \
                  --experiments_dir '<set your own path>' \
                  --classifier_type 'ResNet18' \
                  --data_path '<root your own data path>' \
                  --data_type '<cifar100 or imagenet>' \
                  --alpha_T 0.8 \
                  --rank 0 \
                  --world_size 1 \
                  --multiprocessing_distributed True
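The --lr and --lr_decay_schedule flags above describe a step schedule. Below is a minimal sketch that assumes the learning rate is multiplied by a factor gamma = 0.1 at each listed epoch, which matches the semantics of torch.optim.lr_scheduler.MultiStepLR. This README does not state the decay factor main.py actually uses, so check the script before relying on it. step_lr is a hypothetical helper.

```python
from typing import Sequence


def step_lr(base_lr: float, epoch: int, milestones: Sequence[int], gamma: float = 0.1) -> float:
    """Learning rate at a 0-based `epoch`, decayed by `gamma` once per milestone reached.

    gamma = 0.1 is an assumption, not a documented default of main.py.
    """
    num_decays = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma ** num_decays


if __name__ == "__main__":
    # With --lr 0.1 --lr_decay_schedule 150 225 under the assumed gamma:
    for epoch in (0, 149, 150, 224, 225, 299):
        print(epoch, step_lr(0.1, epoch, milestones=[150, 225]))
```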

Multi-node Training

For example, to train a single model on 2 nodes, run one of the following commands on each node. Replace {master_ip}:{master_port} with the IP address of node #0 and a free port on it:

# on the node #0
$ python3 main.py --lr 0.1 \
                  --lr_decay_schedule 150 225 \
                  --PSKD \
                  --experiments_dir '<set your own path>' \
                  --classifier_type 'ResNet18' \
                  --data_path '<root your own data path>' \
                  --data_type '<cifar100 or imagenet>' \
                  --alpha_T 0.8 \
                  --rank 0 \
                  --world_size 2 \
                  --dist_url tcp://{master_ip}:{master_port} \
                  --multiprocessing_distributed
# on the node #1
$ python3 main.py --lr 0.1 \
                  --lr_decay_schedule 150 225 \
                  --PSKD \
                  --experiments_dir '<set your own path>' \
                  --classifier_type 'ResNet18' \
                  --data_path '<root your own data path>' \
                  --data_type '<cifar100 or imagenet>' \
                  --alpha_T 0.8 \
                  --rank 1 \
                  --world_size 2 \
                  --dist_url tcp://{master_ip}:{master_port} \
                  --multiprocessing_distributed
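In these commands, --rank and --world_size appear to count nodes rather than processes. The sketch below shows the resulting process layout. It assumes main.py follows the convention of the official PyTorch ImageNet example, which uses the same --dist_url, --world_size, --rank and --multiprocessing_distributed flags. Under that convention, one process is spawned per GPU, the total world size is nodes x GPUs per node, and each process's global rank is node_rank x GPUs per node + local GPU index. process_layout is a hypothetical helper.

```python
from typing import List, Tuple


def process_layout(node_rank: int, num_nodes: int, gpus_per_node: int) -> Tuple[int, List[int]]:
    """Return (total number of processes, global ranks of the processes on this node).

    Assumes one process per GPU and node-level --rank / --world_size values.
    """
    world_size = num_nodes * gpus_per_node
    ranks = [node_rank * gpus_per_node + gpu for gpu in range(gpus_per_node)]
    return world_size, ranks


if __name__ == "__main__":
    # Example: --world_size 2 (two nodes), 4 GPUs on each node.
    print(process_layout(node_rank=0, num_nodes=2, gpus_per_node=4))  # (8, [0, 1, 2, 3])
    print(process_layout(node_rank=1, num_nodes=2, gpus_per_node=4))  # (8, [4, 5, 6, 7])
```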

Saving & Loading Checkpoints

Saved Filenames

  • save_dir is determined automatically (with a sequential number suffix) unless specified otherwise.
  • Model checkpoints are saved in ./{experiments_dir}/models/checkpoint_{epoch}.pth.
  • The best checkpoint is saved in ./{experiments_dir}/models/checkpoint_best.pth.

Loading Checkpoints (resume)

  • Pass the checkpoint path via the --resume argument.
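To resume from the most recent epoch rather than the best model, you can look up the newest checkpoint_{epoch}.pth file. Below is a minimal sketch that uses only the standard library and the filename pattern listed above. latest_checkpoint is a hypothetical helper and is not part of this repository.

```python
import re
from pathlib import Path
from typing import Optional

# Per-epoch checkpoints follow the documented pattern checkpoint_{epoch}.pth.
_EPOCH_CKPT = re.compile(r"checkpoint_(\d+)\.pth")


def latest_checkpoint(experiments_dir: str) -> Optional[Path]:
    """Return the per-epoch checkpoint with the highest epoch number, or None if there is none."""
    latest_epoch, latest_path = -1, None
    for path in Path(experiments_dir, "models").glob("checkpoint_*.pth"):
        match = _EPOCH_CKPT.fullmatch(path.name)
        if match is None:  # skips non-epoch files such as checkpoint_best.pth
            continue
        epoch = int(match.group(1))  # numeric comparison, so epoch 10 beats epoch 9
        if epoch > latest_epoch:
            latest_epoch, latest_path = epoch, path
    return latest_path
```

You can then pass the returned path as the --resume argument.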

Experimental Results

Performance measures

  • Top-1 Error / Top-5 Error
  • Negative Log Likelihood (NLL)
  • Expected Calibration Error (ECE)
  • Area Under the Risk-coverage Curve (AURC)
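ECE and AURC are less familiar than error rates and NLL, so here is a plain-Python sketch of their standard definitions for reference. ECE groups predictions by confidence (15 equal-width bins here, an assumed choice) and averages |accuracy - confidence| weighted by bin size. AURC sorts predictions by confidence, computes the error rate among the k most confident samples for every k, and averages those risks. This is not the repository's evaluation code. The values reported in the results tables are scaled (several exceed 1), so apply the scaling described in the paper before comparing.

```python
from typing import Sequence


def expected_calibration_error(
    confidences: Sequence[float], correct: Sequence[bool], n_bins: int = 15
) -> float:
    """ECE = sum over bins of |B_m| / n * |acc(B_m) - conf(B_m)|, with equal-width bins."""
    n = len(confidences)
    ece = 0.0
    for b in range(n_bins):
        lo, hi = b / n_bins, (b + 1) / n_bins
        # bin b holds confidences in (lo, hi]; the first bin also takes exactly 0.0
        idx = [i for i, c in enumerate(confidences) if lo < c <= hi or (b == 0 and c == 0.0)]
        if not idx:
            continue
        acc = sum(correct[i] for i in idx) / len(idx)
        conf = sum(confidences[i] for i in idx) / len(idx)
        ece += len(idx) / n * abs(acc - conf)
    return ece


def aurc(confidences: Sequence[float], correct: Sequence[bool]) -> float:
    """Area under the risk-coverage curve: mean selective risk over all coverage levels."""
    order = sorted(range(len(confidences)), key=lambda i: confidences[i], reverse=True)
    errors = 0
    total_risk = 0.0
    for k, i in enumerate(order, start=1):
        errors += 0 if correct[i] else 1
        total_risk += errors / k  # error rate among the k most confident predictions
    return total_risk / len(order)
```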

Results on CIFAR-100

| Model + Method | Dataset | Top-1 Error | Top-5 Error | NLL | ECE | AURC |
|---|---|---|---|---|---|---|
| PreAct ResNet-18 (baseline) | CIFAR-100 | 24.18 | 6.90 | 1.10 | 11.84 | 67.65 |
| PreAct ResNet-18 + Label Smoothing | CIFAR-100 | 20.94 | 6.02 | 0.98 | 10.79 | 57.74 |
| PreAct ResNet-18 + CS-KD [CVPR'20] | CIFAR-100 | 21.30 | 5.70 | 0.88 | 6.24 | 56.56 |
| PreAct ResNet-18 + TF-KD [CVPR'20] | CIFAR-100 | 22.88 | 6.01 | 1.05 | 11.96 | 61.77 |
| PreAct ResNet-18 + PS-KD | CIFAR-100 | 20.82 | 5.10 | 0.76 | 1.77 | 52.10 |
| PreAct ResNet-101 (baseline) | CIFAR-100 | 20.75 | 5.28 | 0.89 | 10.02 | 55.45 |
| PreAct ResNet-101 + Label Smoothing | CIFAR-100 | 19.84 | 5.07 | 0.93 | 3.43 | 95.76 |
| PreAct ResNet-101 + CS-KD [CVPR'20] | CIFAR-100 | 20.76 | 5.62 | 1.02 | 12.18 | 64.44 |
| PreAct ResNet-101 + TF-KD [CVPR'20] | CIFAR-100 | 20.13 | 5.10 | 0.84 | 6.14 | 58.8 |
| PreAct ResNet-101 + PS-KD | CIFAR-100 | 19.43 | 4.30 | 0.74 | 6.92 | 49.01 |
| DenseNet-121 (baseline) | CIFAR-100 | 20.05 | 4.99 | 0.82 | 7.34 | 52.21 |
| DenseNet-121 + Label Smoothing | CIFAR-100 | 19.80 | 5.46 | 0.92 | 3.76 | 91.06 |
| DenseNet-121 + CS-KD [CVPR'20] | CIFAR-100 | 20.47 | 6.21 | 1.07 | 13.80 | 73.37 |
| DenseNet-121 + TF-KD [CVPR'20] | CIFAR-100 | 19.88 | 5.10 | 0.85 | 7.33 | 69.23 |
| DenseNet-121 + PS-KD | CIFAR-100 | 18.73 | 3.90 | 0.69 | 3.71 | 45.55 |
| ResNeXt-29 (baseline) | CIFAR-100 | 18.65 | 4.47 | 0.74 | 4.17 | 44.27 |
| ResNeXt-29 + Label Smoothing | CIFAR-100 | 17.60 | 4.23 | 1.05 | 22.14 | 41.92 |
| ResNeXt-29 + CS-KD [CVPR'20] | CIFAR-100 | 18.26 | 4.37 | 0.80 | 5.95 | 42.11 |
| ResNeXt-29 + TF-KD [CVPR'20] | CIFAR-100 | 17.33 | 3.87 | 0.74 | 6.73 | 40.34 |
| ResNeXt-29 + PS-KD | CIFAR-100 | 17.28 | 3.60 | 0.72 | 9.18 | 40.19 |
| PyramidNet-200 (baseline) | CIFAR-100 | 16.80 | 3.69 | 0.73 | 8.04 | 36.95 |
| PyramidNet-200 + Label Smoothing | CIFAR-100 | 17.82 | 4.72 | 0.89 | 3.46 | 105.02 |
| PyramidNet-200 + CS-KD [CVPR'20] | CIFAR-100 | 18.31 | 5.70 | 1.17 | 14.70 | 70.05 |
| PyramidNet-200 + TF-KD [CVPR'20] | CIFAR-100 | 16.48 | 3.37 | 0.79 | 10.48 | 37.04 |
| PyramidNet-200 + PS-KD | CIFAR-100 | 15.49 | 3.08 | 0.56 | 1.83 | 32.14 |

Results on ImageNet

| Model + Method | Dataset | Top-1 Error | Top-5 Error | NLL | ECE | AURC |
|---|---|---|---|---|---|---|
| DenseNet-264* | ImageNet | 22.15 | 6.12 | -- | -- | -- |
| ResNet-152 | ImageNet | 22.19 | 6.19 | 0.88 | 3.84 | 61.79 |
| ResNet-152 + Label Smoothing | ImageNet | 21.73 | 5.85 | 0.92 | 3.91 | 68.24 |
| ResNet-152 + CS-KD [CVPR'20] | ImageNet | 21.61 | 5.92 | 0.90 | 5.79 | 62.12 |
| ResNet-152 + TF-KD [CVPR'20] | ImageNet | 22.76 | 6.43 | 0.91 | 4.70 | 65.28 |
| ResNet-152 + PS-KD | ImageNet | 21.41 | 5.86 | 0.84 | 2.51 | 61.01 |

* denotes results reported in the original papers

Citation

If you find this repository useful, please consider starring it and citing PS-KD:

@InProceedings{Kim_2021_ICCV,
    author    = {Kim, Kyungyul and Ji, ByeongMoon and Yoon, Doyoung and Hwang, Sangheum},
    title     = {Self-Knowledge Distillation With Progressive Refinement of Targets},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2021},
    pages     = {6567-6576}
}

Contact for Issues

License

Copyright (c) 2021-present LG CNS Corp.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
Owner
Open source repository of LG CNS AI Research (LAIR), LG