Ensemble Knowledge Guided Sub-network Search and Fine-tuning for Filter Pruning

Overview

This repository is the official TensorFlow implementation of the paper:

Ensemble Knowledge Guided Sub-network Search and Fine-tuning for Filter Pruning [paper link]

and TensorFlow 2 example code for
   "Custom layers", "Custom training loop", "XLA (JIT)-compiling", "Distributed learning", and "Gradients accumulator".

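As a rough illustration of the "Gradients accumulator" idea listed above, the sketch below averages gradients over several micro-batches and applies a single update. It is framework-free Python rather than the repository's TensorFlow code; the toy model `y = w * x` and the names `mse_grad` and `accumulated_step` are assumptions made for this example.

```python
def mse_grad(w, batch):
    """Gradient of the mean squared error for the toy model y = w * x."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_step(w, micro_batches, lr):
    """Average gradients over micro-batches, then apply one update."""
    accum = 0.0
    for batch in micro_batches:
        accum += mse_grad(w, batch)   # accumulate, no weight update yet
    accum /= len(micro_batches)       # average over the accumulation steps
    return w - lr * accum             # single optimizer step


# Hypothetical toy data following y = 2x, split into two equal micro-batches.
data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]
micro_batches = [data[:2], data[2:]]

w_accum = accumulated_step(0.0, micro_batches, lr=0.01)
w_full = 0.0 - 0.01 * mse_grad(0.0, data)   # one step on the full batch
```

With equally sized micro-batches, the accumulated update matches the full-batch update, which is why accumulation can stand in for a larger batch when memory is limited.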
Paper abstract

Conventional NAS-based pruning algorithms aim to find the sub-network with the best validation performance. However, validation performance does not successfully represent test performance, i.e., potential performance. Also, although fine-tuning the pruned network to restore the performance drop is an inevitable process, few studies have handled this issue. This paper proposes a novel sub-network search and fine-tuning method, i.e., Ensemble Knowledge Guidance (EKG). First, we experimentally prove that the fluctuation of the loss landscape is an effective metric to evaluate the potential performance. In order to search a sub-network with the smoothest loss landscape at a low cost, we propose a pseudo-supernet built by an ensemble sub-network knowledge distillation. Next, we propose a novel fine-tuning that re-uses the information of the search phase. We store the interim sub-networks, that is, the by-products of the search phase, and transfer their knowledge into the pruned network. Note that EKG is easy to be plugged-in and computationally efficient. For example, in the case of ResNet-50, about 45% of FLOPS is removed without any performance drop in only 315 GPU hours.
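To make the idea of ensemble knowledge distillation more concrete, here is a minimal sketch of one common formulation: the student is trained to match the average of several teachers' softmax outputs under a KL divergence. The paper's exact loss may differ; the averaging rule, the `temperature` parameter, and the function names are assumptions for illustration.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax with an optional temperature."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def ensemble_kd_loss(student_logits, teacher_logits_list, temperature=1.0):
    """KL(mean teacher distribution || student distribution)."""
    teacher_probs = [softmax(t, temperature) for t in teacher_logits_list]
    n_teachers = len(teacher_probs)
    n_classes = len(student_logits)
    # Assumption: the ensemble target is the plain average of teacher outputs.
    ensemble = [sum(p[i] for p in teacher_probs) / n_teachers
                for i in range(n_classes)]
    student = softmax(student_logits, temperature)
    return sum(q * math.log(q / p)
               for q, p in zip(ensemble, student) if q > 0.0)
```

The loss is zero when the student reproduces the ensemble distribution and grows as the two diverge.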


Conceptual visualization of the goal of the proposed method.

Contribution points and key features

  • The smoothness of the loss landscape is presented as a new tool for measuring the potential performance of a sub-network in NAS-based pruning. Experimental evidence is also provided that loss landscape fluctuation correlates more strongly with test performance than validation performance does.
  • A pseudo-supernet based on ensemble sub-network knowledge distillation is proposed to find a sub-network with a smoother loss landscape without increasing complexity. It helps NAS-based pruning to prune all pre-trained networks, and also allows optimal sub-network(s) to be found more accurately.
  • To our knowledge, this paper provides the world-first approach to store the information of the search phase in a memory bank and to reuse it in the fine-tuning phase of the pruned network. The proposed memory bank greatly improves the performance of the pruned network.

Requirements

  • TensorFlow >= 2.7 (tested on 2.7-2.8)
  • pickle
  • tqdm

How to run

  1. Move to the codebase.
  2. Train and evaluate the model with the commands below.
  # ResNet-56 on CIFAR10
  python train_cifar.py --gpu_id 0 --arch ResNet-56 --dataset CIFAR10 --search_target_rate 0.45 --train_path ../test
  python test.py --gpu_id 0 --arch ResNet-56 --dataset CIFAR10 --trained_param ../test/trained_param.pkl
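
The test command above reads a pickled parameter file. Its internal layout is not documented here, so the following sketch only loads a pickle and reports what sits at the top level; the helper name `summarize_params` and the assumption that the file may hold a dictionary are illustrative, not part of the repository.

```python
import pickle


def summarize_params(path):
    """Load a pickled object and describe its top level."""
    with open(path, "rb") as f:
        params = pickle.load(f)   # only unpickle files you trust
    if isinstance(params, dict):
        # Assumption: parameters are keyed by name; list the keys.
        return sorted(str(key) for key in params)
    return [type(params).__name__]
```

Calling `summarize_params("../test/trained_param.pkl")` after training would list the top-level keys if the file holds a dictionary, or the type name of the stored object otherwise.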

Experimental results


(Left) Potential performance vs. validation loss. (Right) Potential performance vs. condition number. 50 sub-networks of ResNet-56 trained on CIFAR10 were used for this experiment.
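
For readers unfamiliar with the condition number as a smoothness measure, the sketch below computes it for a symmetric 2x2 matrix, read here as the Hessian of a toy loss. How the condition number is obtained for real sub-networks is not described in this README, so the Hessian interpretation and the function name are assumptions.

```python
import math


def condition_number_2x2(a, b, c):
    """Condition number of the symmetric matrix [[a, b], [b, c]]."""
    mean = (a + c) / 2.0
    radius = math.sqrt(((a - c) / 2.0) ** 2 + b ** 2)
    lam_hi, lam_lo = mean + radius, mean - radius   # closed-form eigenvalues
    big = max(abs(lam_hi), abs(lam_lo))
    small = min(abs(lam_hi), abs(lam_lo))
    return math.inf if small == 0.0 else big / small


round_bowl = condition_number_2x2(1.0, 0.0, 1.0)       # same curvature everywhere
narrow_valley = condition_number_2x2(100.0, 0.0, 1.0)  # sharp in one direction
```

A value near 1 means the curvature is similar in every direction, while a large value indicates a landscape that is much steeper along some directions than others.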


Visualization of loss landscapes of sub-networks searched by various filter importance scoring algorithms.

Comparison with various pruning techniques for ResNet family trained on ImageNet.


Performance analysis for ResNet-50 trained on ImageNet-2012. The left plot shows Top-1 accuracy against the FLOPs reduction rate, and the right plot shows Top-1 accuracy against GPU hours.
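
The FLOPs reduction rate on the horizontal axis can be illustrated with simple arithmetic. The sketch below counts multiply-accumulate operations for convolution layers and compares a hypothetical two-layer stack before and after removing half of the first layer's filters; the layer sizes are invented for the example and do not come from the paper.

```python
def conv_flops(h_out, w_out, c_in, c_out, kernel):
    """Multiply-accumulate count of one convolution layer (bias ignored)."""
    return h_out * w_out * c_in * c_out * kernel * kernel


def flops_reduction_rate(original, pruned):
    """Fraction of the original FLOPs removed by pruning."""
    return 1.0 - pruned / original


# Hypothetical stack: 16 -> 32 -> 32 channels, 3x3 kernels, 32x32 feature maps.
original = conv_flops(32, 32, 16, 32, 3) + conv_flops(32, 32, 32, 32, 3)

# Keep 16 of the 32 filters in the first layer. The second layer then sees
# only 16 input channels, so its cost drops as well.
pruned = conv_flops(32, 32, 16, 16, 3) + conv_flops(32, 32, 16, 32, 3)

rate = flops_reduction_rate(original, pruned)
```

Removing filters from one layer also shrinks the input of the next layer, which is why filter pruning can cut more computation than the pruned layer alone accounts for.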

Reference

@article{lee2022ensemble,
  title        = {Ensemble Knowledge Guided Sub-network Search and Fine-tuning for Filter Pruning},
  author       = {Seunghyun Lee and Byung Cheol Song},
  year         = 2022,
  journal      = {arXiv preprint arXiv:2203.02651}
}

Owner
Seunghyun Lee
Knowledge distillation; Neural network light-weighting; Tensorflow