A Closer Look at Structured Pruning for Neural Network Compression

Overview

A Closer Look at Structured Pruning for Neural Network Compression

Code used to reproduce experiments in https://arxiv.org/abs/1810.04622.

To prune, we fill our networks with custom MaskBlocks, which are manipulated using Pruner in funcs.py. There will certainly be a better way to do this, but we leave this as an exercise to someone who can code much better than we can.

Setup

This is best done in a clean conda environment:

conda create -n prunes python=3.6
conda activate prunes
conda install pytorch torchvision -c pytorch

Repository layout

-train.py: contains all of the code for training large models from scratch and for training pruned models from scratch
-prune.py: contains the code for pruning trained models
-funcs.py: contains useful pruning functions and any functions we used commonly

CIFAR Experiments

First, you will need some initial models.

To train a WRN-40-2:

python train.py --net='res' --depth=40 --width=2.0 --data_loc= --save_file='res'

The default arguments of train.py are suitable for training WRNs. The following trains a DenseNet-BC-100 (k=12) with its default hyperparameters:

python train.py --net='dense' --depth=100 --data_loc= --save_file='dense' --no_epochs 300 -b 64 --epoch_step '[150,225]' --weight_decay 0.0001 --lr_decay_ratio 0.1

These will automatically save checkpoints to the checkpoints folder.

Pruning

Once training is finished, we can prune our networks using prune.py (defaults are set to WRN pruning, so extra arguments are needed for DenseNets)

python prune.py --net='res'   --data_loc= --base_model='res' --save_file='res_fisher'
python prune.py --net='res'   --data_loc= --l1_prune=True --base_model='res' --save_file='res_l1'

python prune.py --net='dense' --depth 100 --data_loc= --base_model='dense' --save_file='dense_fisher' --learning_rate 1e-3 --weight_decay 1e-4 --batch_size 64 --no_epochs 2600
python prune.py --net='dense' --depth 100 --data_loc= --l1_prune=True --base_model='dense' --save_file='dense_l1'  --learning_rate 1e-3 --weight_decay 1e-4 --batch_size 64  --no_epochs 2600

Note that the default is to perform Fisher pruning, so you don't need to pass a flag to use it.
Once finished, we can train the pruned models from scratch, e.g.:

python train.py --data_loc= --net='res' --base_file='res_fisher__prunes' --deploy --mask=1 --save_file='res_fisher__prunes_scratch'

Each model can then be evaluated using:

python train.py --deploy --eval --data_loc= --net='res' --mask=1 --base_file='res_fisher__prunes'

Training Reduced models

This can be done by varying the input arguments to train.py. To reduce depth or width of a WRN, change the corresponding option:

python train.py --net='res' --depth= --width= --data_loc= --save_file='res_reduced'

To add bottlenecks, use the following:

python train.py --net='res' --depth=40 --width=2.0 --data_loc= --save_file='res_bottle' --bottle --bottle_mult 

With DenseNets you can modify the depth or growth, or use --bottle --bottle_mult as above.

Acknowledgements

Jack Turner wrote the L1 stuff, and some other stuff for that matter.

Code has been liberally borrowed from many a repo, including, but not limited to:

https://github.com/xternalz/WideResNet-pytorch
https://github.com/bamos/densenet.pytorch
https://github.com/kuangliu/pytorch-cifar
https://github.com/ShichenLiu/CondenseNet

Citing this work

If you would like to cite this work, please use the following bibtex entry:

@article{crowley2018pruning,
  title={A Closer Look at Structured Pruning for Neural Network Compression},
  author={Crowley, Elliot J and Turner, Jack and Storkey, Amos and O'Boyle, Michael},
  journal={arXiv preprint arXiv:1810.04622},
  year={2018},
  }
Owner
Bayesian and Neural Systems Group
Machine learning research group @ University of Edinburgh
Bayesian and Neural Systems Group
This repo contains the source code and a benchmark for predicting user's utilities with Machine Learning techniques for Computational Persuasion

Machine Learning for Argument-Based Computational Persuasion This repo contains the source code and a benchmark for predicting user's utilities with M

Ivan Donadello 4 Nov 07, 2022
Predicting Price of house by considering ,house age, Distance from public transport

House-Price-Prediction Predicting Price of house by considering ,house age, Distance from public transport, No of convenient stores around house etc..

Musab Jaleel 1 Jan 08, 2022
PyTorch implementation of "Learning to Discover Cross-Domain Relations with Generative Adversarial Networks"

DiscoGAN in PyTorch PyTorch implementation of Learning to Discover Cross-Domain Relations with Generative Adversarial Networks. * All samples in READM

Taehoon Kim 1k Jan 04, 2023
Transformers based fully on MLPs

Awesome MLP-based Transformers papers An up-to-date list of Transformers based fully on MLPs without attention! Why this repo? After transformers and

Fawaz Sammani 35 Dec 30, 2022
Implementation of the bachelor's thesis "Real-time stock predictions with deep learning and news scraping".

Real-time stock predictions with deep learning and news scraping This repository contains a partial implementation of my bachelor's thesis "Real-time

David Álvarez de la Torre 0 Feb 09, 2022
BankNote-Net: Open dataset and encoder model for assistive currency recognition

BankNote-Net: Open Dataset for Assistive Currency Recognition Millions of people around the world have low or no vision. Assistive software applicatio

Microsoft 13 Oct 28, 2022
Create UIs for prototyping your machine learning model in 3 minutes

Note: We just launched Hosted, where anyone can upload their interface for permanent hosting. Check it out! Welcome to Gradio Quickly create customiza

Gradio 11.7k Jan 07, 2023
Official PyTorch code for WACV 2022 paper "CFLOW-AD: Real-Time Unsupervised Anomaly Detection with Localization via Conditional Normalizing Flows"

CFLOW-AD: Real-Time Unsupervised Anomaly Detection with Localization via Conditional Normalizing Flows WACV 2022 preprint:https://arxiv.org/abs/2107.1

Denis 156 Dec 28, 2022
[NAACL & ACL 2021] SapBERT: Self-alignment pretraining for BERT.

SapBERT: Self-alignment pretraining for BERT This repo holds code for the SapBERT model presented in our NAACL 2021 paper: Self-Alignment Pretraining

Cambridge Language Technology Lab 104 Dec 07, 2022
ScaleNet: A Shallow Architecture for Scale Estimation

ScaleNet: A Shallow Architecture for Scale Estimation Repository for the code of ScaleNet paper: "ScaleNet: A Shallow Architecture for Scale Estimatio

Axel Barroso 34 Nov 09, 2022
Training Confidence-Calibrated Classifier for Detecting Out-of-Distribution Samples / ICLR 2018

Training Confidence-Calibrated Classifier for Detecting Out-of-Distribution Samples This project is for the paper "Training Confidence-Calibrated Clas

168 Nov 29, 2022
SimulLR - PyTorch Implementation of SimulLR

PyTorch Implementation of SimulLR There is an interesting work[1] about simultan

11 Dec 22, 2022
Unbalanced Feature Transport for Exemplar-based Image Translation (CVPR 2021)

UNITE and UNITE+ Unbalanced Feature Transport for Exemplar-based Image Translation (CVPR 2021) Unbalanced Intrinsic Feature Transport for Exemplar-bas

Fangneng Zhan 183 Nov 09, 2022
Code for "Human Pose Regression with Residual Log-likelihood Estimation", ICCV 2021 Oral

Human Pose Regression with Residual Log-likelihood Estimation [Paper] [arXiv] [Project Page] Human Pose Regression with Residual Log-likelihood Estima

JeffLi 347 Dec 24, 2022
Code of 3D Shape Variational Autoencoder Latent Disentanglement via Mini-Batch Feature Swapping for Bodies and Faces

3D Shape Variational Autoencoder Latent Disentanglement via Mini-Batch Feature Swapping for Bodies and Faces Installation After cloning the repo open

37 Dec 03, 2022
Official Code for "Constrained Mean Shift Using Distant Yet Related Neighbors for Representation Learning"

CMSF Official Code for "Constrained Mean Shift Using Distant Yet Related Neighbors for Representation Learning" Requirements Python = 3.7.6 PyTorch

4 Nov 25, 2022
Self-Regulated Learning for Egocentric Video Activity Anticipation

Self-Regulated Learning for Egocentric Video Activity Anticipation Introduction This is a Pytorch implementation of the model described in our paper:

qzhb 13 Sep 23, 2022
Pyramid Scene Parsing Network, CVPR2017.

Pyramid Scene Parsing Network by Hengshuang Zhao, Jianping Shi, Xiaojuan Qi, Xiaogang Wang, Jiaya Jia, details are in project page. Introduction This

Hengshuang Zhao 1.5k Jan 05, 2023
Deep High-Resolution Representation Learning for Human Pose Estimation

Deep High-Resolution Representation Learning for Human Pose Estimation (accepted to CVPR2019) News If you are interested in internship or research pos

HRNet 167 Dec 27, 2022
A PyTorch implementation of ViTGAN based on paper ViTGAN: Training GANs with Vision Transformers.

ViTGAN: Training GANs with Vision Transformers A PyTorch implementation of ViTGAN based on paper ViTGAN: Training GANs with Vision Transformers. Refer

Hong-Jia Chen 127 Dec 23, 2022