Block Sparse movement pruning

Overview

Movement Pruning: Adaptive Sparsity by Fine-Tuning

Magnitude pruning is a widely used strategy for reducing model size in pure supervised learning; however, it is less effective in the transfer learning regime that has become standard for state-of-the-art natural language processing applications. We propose the use of movement pruning, a simple, deterministic first-order weight pruning method that is more adaptive to pretrained model fine-tuning. Experiments show that when pruning large pretrained language models, movement pruning shows significant improvements in high-sparsity regimes. When combined with distillation, the approach achieves minimal accuracy loss with down to only 3% of the model parameters:

Fine-pruning+Distillation
(Teacher=BERT-base fine-tuned)
BERT base
fine-tuned
Remaining
Weights (%)
Magnitude Pruning L0 Regularization Movement Pruning Soft Movement Pruning
SQuAD - Dev
EM/F1
80.4/88.1 10%
3%
70.2/80.1
45.5/59.6
72.4/81.9
64.3/75.8
75.6/84.3
67.5/78.0
76.6/84.9
72.7/82.3
MNLI - Dev
acc/MM acc
84.5/84.9 10%
3%
78.3/79.3
69.4/70.6
78.7/79.7
76.0/76.2
80.1/80.4
76.5/77.4
81.2/81.8
79.5/80.1
QQP - Dev
acc/F1
91.4/88.4 10%
3%
79.8/65.0
72.4/57.8
88.1/82.8
87.0/81.9
89.7/86.2
86.1/81.5
90.2/86.8
89.1/85.5

This page contains information on how to fine-prune pre-trained models such as BERT to obtain extremely sparse models with movement pruning. In contrast to magnitude pruning which selects weights that are far from 0, movement pruning retains weights that are moving away from 0.

For more information, we invite you to check out our paper. You can also have a look at this fun Explain Like I'm Five introductory slide deck.

Extreme sparsity and efficient storage

One promise of extreme pruning is to obtain extremely small models that can be easily sent (and stored) on edge devices. By setting weights to 0., we reduce the amount of information we need to store, and thus decreasing the memory size. We are able to obtain extremely sparse fine-pruned models with movement pruning: ~95% of the dense performance with ~5% of total remaining weights in the BERT encoder.

In this notebook, we showcase how we can leverage standard tools that exist out-of-the-box to efficiently store an extremely sparse question answering model (only 6% of total remaining weights in the encoder). We are able to reduce the memory size of the encoder from the 340MB (the orignal dense BERT) to 11MB, without any additional training of the model (every operation is performed post fine-pruning). It is sufficiently small to store it on a 91' floppy disk 📎 !

While movement pruning does not directly optimize for memory footprint (but rather the number of non-null weights), we hypothetize that further memory compression ratios can be achieved with specific quantization aware trainings (see for instance Q8BERT, And the Bit Goes Down or Quant-Noise).

Fine-pruned models

As examples, we release two English PruneBERT checkpoints (models fine-pruned from a pre-trained BERT checkpoint), one on SQuAD and the other on MNLI.

  • prunebert-base-uncased-6-finepruned-w-distil-squad
    Pre-trained BERT-base-uncased fine-pruned with soft movement pruning on SQuAD v1.1. We use an additional distillation signal from BERT-base-uncased finetuned on SQuAD. The encoder counts 6% of total non-null weights and reaches 83.8 F1 score. The model can be accessed with: pruned_bert = BertForQuestionAnswering.from_pretrained("huggingface/prunebert-base-uncased-6-finepruned-w-distil-squad")
  • prunebert-base-uncased-6-finepruned-w-distil-mnli
    Pre-trained BERT-base-uncased fine-pruned with soft movement pruning on MNLI. We use an additional distillation signal from BERT-base-uncased finetuned on MNLI. The encoder counts 6% of total non-null weights and reaches 80.7 (matched) accuracy. The model can be accessed with: pruned_bert = BertForSequenceClassification.from_pretrained("huggingface/prunebert-base-uncased-6-finepruned-w-distil-mnli")

How to fine-prune?

Setup

The code relies on the 🤗 Transformers library. In addition to the dependencies listed in the examples folder, you should install a few additional dependencies listed in the requirements.txt file: pip install -r requirements.txt.

Note that we built our experiments on top of a stabilized version of the library (commit https://github.com/huggingface/transformers/commit/352d5472b0c1dec0f420d606d16747d851b4bda8): we do not guarantee that everything is still compatible with the latest version of the master branch.

Fine-pruning with movement pruning

Below, we detail how to reproduce the results reported in the paper. We use SQuAD as a running example. Commands (and scripts) can be easily adapted for other tasks.

The following command fine-prunes a pre-trained BERT-base on SQuAD using movement pruning towards 15% of remaining weights (85% sparsity). Note that we freeze all the embeddings modules (from their pre-trained value) and only prune the Fully Connected layers in the encoder (12 layers of Transformer Block).

SERIALIZATION_DIR=<OUTPUT_DIR>
SQUAD_DATA=squad_data

mkdir $SQUAD_DATA
cd $SQUAD_DATA
wget -q https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json
wget -q https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json
cd ..


python examples/movement-pruning/masked_run_squad.py \
    --output_dir $SERIALIZATION_DIR \
    --data_dir $SQUAD_DATA \
    --train_file train-v1.1.json \
    --predict_file dev-v1.1.json \
    --do_train --do_eval --do_lower_case \
    --model_type masked_bert \
    --model_name_or_path bert-base-uncased \
    --per_gpu_train_batch_size 16 \
    --warmup_steps 5400 \
    --num_train_epochs 10 \
    --learning_rate 3e-5 --mask_scores_learning_rate 1e-2 \
    --initial_threshold 1 --final_threshold 0.15 \
    --initial_warmup 1 --final_warmup 2 \
    --pruning_method topK --mask_init constant --mask_scale 0.

Fine-pruning with other methods

We can also explore other fine-pruning methods by changing the pruning_method parameter:

Soft movement pruning

python examples/movement-pruning/masked_run_squad.py \
    --output_dir $SERIALIZATION_DIR \
    --data_dir $SQUAD_DATA \
    --train_file train-v1.1.json \
    --predict_file dev-v1.1.json \
    --do_train --do_eval --do_lower_case \
    --model_type masked_bert \
    --model_name_or_path bert-base-uncased \
    --per_gpu_train_batch_size 16 \
    --warmup_steps 5400 \
    --num_train_epochs 10 \
    --learning_rate 3e-5 --mask_scores_learning_rate 1e-2 \
    --initial_threshold 0 --final_threshold 0.1 \
    --initial_warmup 1 --final_warmup 2 \
    --pruning_method sigmoied_threshold --mask_init constant --mask_scale 0. \
    --regularization l1 --final_lambda 400.

L0 regularization

python examples/movement-pruning/masked_run_squad.py \
    --output_dir $SERIALIZATION_DIR \
    --data_dir $SQUAD_DATA \
    --train_file train-v1.1.json \
    --predict_file dev-v1.1.json \
    --do_train --do_eval --do_lower_case \
    --model_type masked_bert \
    --model_name_or_path bert-base-uncased \
    --per_gpu_train_batch_size 16 \
    --warmup_steps 5400 \
    --num_train_epochs 10 \
    --learning_rate 3e-5 --mask_scores_learning_rate 1e-1 \
    --initial_threshold 1. --final_threshold 1. \
    --initial_warmup 1 --final_warmup 1 \
    --pruning_method l0 --mask_init constant --mask_scale 2.197 \
    --regularization l0 --final_lambda 125.

Iterative Magnitude Pruning

python examples/movement-pruning/masked_run_squad.py \
    --output_dir ./dbg \
    --data_dir examples/distillation/data/squad_data \
    --train_file train-v1.1.json \
    --predict_file dev-v1.1.json \
    --do_train --do_eval --do_lower_case \
    --model_type masked_bert \
    --model_name_or_path bert-base-uncased \
    --per_gpu_train_batch_size 16 \
    --warmup_steps 5400 \
    --num_train_epochs 10 \
    --learning_rate 3e-5 \
    --initial_threshold 1 --final_threshold 0.15 \
    --initial_warmup 1 --final_warmup 2 \
    --pruning_method magnitude

After fine-pruning

Counting parameters

Regularization based pruning methods (soft movement pruning and L0 regularization) rely on the penalty to induce sparsity. The multiplicative coefficient controls the sparsity level. To obtain the effective sparsity level in the encoder, we simply count the number of activated (non-null) weights:

python examples/movement-pruning/counts_parameters.py \
    --pruning_method sigmoied_threshold \
    --threshold 0.1 \
    --serialization_dir $SERIALIZATION_DIR

Pruning once for all

Once the model has been fine-pruned, the pruned weights can be set to 0. once for all (reducing the amount of information to store). In our running experiments, we can convert a MaskedBertForQuestionAnswering (a BERT model augmented to enable on-the-fly pruning capabilities) to a standard BertForQuestionAnswering:

python examples/movement-pruning/bertarize.py \
    --pruning_method sigmoied_threshold \
    --threshold 0.1 \
    --model_name_or_path $SERIALIZATION_DIR

Hyper-parameters

For reproducibility purposes, we share the detailed results presented in the paper. These tables exhaustively describe the individual hyper-parameters used for each data point.

Inference speed

Early experiments show that even though models fine-pruned with (soft) movement pruning are extremely sparse, they do not benefit from significant improvement in terms of inference speed when using the standard PyTorch inference. We are currently benchmarking and exploring inference setups specifically for sparse architectures. In particular, hardware manufacturers are announcing devices that will speedup inference for sparse networks considerably.

Citation

If you find this resource useful, please consider citing the following paper:

@article{sanh2020movement,
    title={Movement Pruning: Adaptive Sparsity by Fine-Tuning},
    author={Victor Sanh and Thomas Wolf and Alexander M. Rush},
    year={2020},
    eprint={2005.07683},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
Owner
Hugging Face
Solving NLP, one commit at a time!
Hugging Face
Cowsay - A rewrite of cowsay in python

Python Cowsay A rewrite of cowsay in python. Allows for parsing of existing .cow

James Ansley 3 Jun 27, 2022
Convert Table data to approximate values with GUI

Table_Editor Convert Table data to approximate values with GUIs... usage - Import methods for extension Tables. Imported method supposed to have only

CLJ 1 Jan 10, 2022
Data Consistency for Magnetic Resonance Imaging

Data Consistency for Magnetic Resonance Imaging Data Consistency (DC) is crucial for generalization in multi-modal MRI data and robustness in detectin

Dimitris Karkalousos 19 Dec 12, 2022
Official codes: Self-Supervised Learning by Estimating Twin Class Distribution

TWIST: Self-Supervised Learning by Estimating Twin Class Distributions Codes and pretrained models for TWIST: @article{wang2021self, title={Self-Sup

Bytedance Inc. 85 Dec 15, 2022
Segment axon and myelin from microscopy data using deep learning

Segment axon and myelin from microscopy data using deep learning. Written in Python. Using the TensorFlow framework. Based on a convolutional neural network architecture. Pixels are classified as eit

NeuroPoly 103 Nov 29, 2022
Structure-Preserving Deraining with Residue Channel Prior Guidance (ICCV2021)

SPDNet Structure-Preserving Deraining with Residue Channel Prior Guidance (ICCV2021) Requirements Linux Platform NVIDIA GPU + CUDA CuDNN PyTorch == 0.

41 Dec 12, 2022
Supervised domain-agnostic prediction framework for probabilistic modelling

A supervised domain-agnostic framework that allows for probabilistic modelling, namely the prediction of probability distributions for individual data

The Alan Turing Institute 112 Oct 23, 2022
Pytorch implementation of FlowNet by Dosovitskiy et al.

FlowNetPytorch Pytorch implementation of FlowNet by Dosovitskiy et al. This repository is a torch implementation of FlowNet, by Alexey Dosovitskiy et

Clément Pinard 762 Jan 02, 2023
9th place solution

AllDataAreExt-Galixir-Kaggle-HPA-2021-Solution Team Members Qishen Ha is Master of Engineering from the University of Tokyo. Machine Learning Engineer

daishu 5 Nov 18, 2021
Pytorch implementation for the paper: Contrastive Learning for Cold-start Recommendation

Contrastive Learning for Cold-start Recommendation This is our Pytorch implementation for the paper: Yinwei Wei, Xiang Wang, Qi Li, Liqiang Nie, Yan L

45 Dec 13, 2022
Keyword spotting on Arm Cortex-M Microcontrollers

Keyword spotting for Microcontrollers This repository consists of the tensorflow models and training scripts used in the paper: Hello Edge: Keyword sp

Arm Software 1k Dec 30, 2022
Accurate Phylogenetic Inference with Symmetry-Preserving Neural Networks

Accurate Phylogenetic Inference with a Symmetry-preserving Neural Network Model Claudia Solis-Lemus Shengwen Yang Leonardo Zepeda-Núñez This repositor

Leonardo Zepeda-Núñez 2 Feb 11, 2022
以孤立语假设和宽度优先搜索为基础,构建了一种多通道堆叠注意力Transformer结构的斗地主ai

ddz-ai 介绍 斗地主是一种扑克游戏。游戏最少由3个玩家进行,用一副54张牌(连鬼牌),其中一方为地主,其余两家为另一方,双方对战,先出完牌的一方获胜。 ddz-ai以孤立语假设和宽度优先搜索为基础,构建了一种多通道堆叠注意力Transformer结构的系统,使其经过大量训练后,能在实际游戏中获

freefuiiismyname 88 May 15, 2022
Code for the ICML 2021 paper: "ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision"

ViLT Code for the paper: "ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision" Install pip install -r requirements.txt pip

Wonjae Kim 922 Jan 01, 2023
A Tensorfflow implementation of Attend, Infer, Repeat

Attend, Infer, Repeat: Fast Scene Understanding with Generative Models This is an unofficial Tensorflow implementation of Attend, Infear, Repeat (AIR)

Adam Kosiorek 82 May 27, 2022
Semantic Segmentation of images using PixelLib with help of Pascalvoc dataset trained with Deeplabv3+ framework.

CARscan- Approach 1 - Segmentation of images by detecting contours. It failed because in images with elements along with cars were also getting detect

Padmanabha Banerjee 5 Jul 29, 2021
I explore rock vs. mine prediction using a SONAR dataset

I explore rock vs. mine prediction using a SONAR dataset. Using a Logistic Regression Model for my prediction algorithm, I intend on predicting what an object is based on supervised learning.

Jeff Shen 1 Jan 11, 2022
Google Landmark Recogntion and Retrieval 2021 Solutions

Google Landmark Recogntion and Retrieval 2021 Solutions In this repository you can find solution and code for Google Landmark Recognition 2021 and Goo

Vadim Timakin 5 Nov 25, 2022
A Multi-modal Model Chinese Spell Checker Released on ACL2021.

ReaLiSe ReaLiSe is a multi-modal Chinese spell checking model. This the office code for the paper Read, Listen, and See: Leveraging Multimodal Informa

DaDa 106 Dec 29, 2022
Bridging Vision and Language Model

BriVL BriVL (Bridging Vision and Language Model) 是首个中文通用图文多模态大规模预训练模型。BriVL模型在图文检索任务上有着优异的效果,超过了同期其他常见的多模态预训练模型(例如UNITER、CLIP)。 BriVL论文:WenLan: Bridgi

235 Dec 27, 2022