Source code for NAACL 2021 paper "TR-BERT: Dynamic Token Reduction for Accelerating BERT Inference"

Related tags

Deep LearningTR-BERT
Overview

TR-BERT

Source code and dataset for "TR-BERT: Dynamic Token Reduction for Accelerating BERT Inference".

model

The code is based on huggaface's transformers. Thanks to them! We will release all the source code in the future.

Requirement

Install dependencies and apex:

pip3 install -r requirement.txt
pip3 install --editable transformers

Pretrained models

Download the DistilBERT-3layer and BERT-1024 from Google Drive/Tsinghua Cloud.

Classfication

Download the IMDB, Yelp, 20News datasets from Google Drive/Tsinghua Cloud.

Download the Hyperpartisan dataset, and randomly split it into train/dev/test set: python3 split_hyperpartisan.py

Train BERT/DistilBERT Model

Use flag --do train:

python3 run_classification.py  --task_name imdb  --model_type bert  --model_name_or_path bert-base-uncased --data_dir imdb --max_seq_length 512  --per_gpu_train_batch_size 8  --per_gpu_eval_batch_size 16 --gradient_accumulation_steps 4 --learning_rate 3e-5 --save_steps 2000  --num_train_epochs 5  --output_dir imdb_models/bert_base  --do_lower_case  --do_eval  --evaluate_during_training  --do_train

where task_name can be set as imdb/yelp_f/20news/hyperpartisan for different tasks and model type can be set as bert/distilbert for different models.

Compute Graident for Residual Strategy

Use flag --do_eval_grad.

python3 run_classification.py  --task_name imdb  --model_type bert  --model_name_or_path imdb_models/bert_base --data_dir imdb --max_seq_length 512  --per_gpu_train_batch_size 8  --per_gpu_eval_batch_size 8  --output_dir imdb_models/bert_base  --do_lower_case  --do_eval_grad

This step doesn't supoort data DataParallel or DistributedDataParallel currently and should be done in a single GPU.

Train the policy network solely

Start from the checkpoint from the task-specific fine-tuned model. Change model_type from bert to autobert, and run with flag --do_train --train_rl:

python3 run_classification.py  --task_name imdb  --model_type autobert  --model_name_or_path imdb_models/bert_base --data_dir imdb --max_seq_length 512  --per_gpu_train_batch_size 8  --per_gpu_eval_batch_size 8 --gradient_accumulation_steps 4 --learning_rate 3e-5 --save_steps 2000  --num_train_epochs 3  --output_dir imdb_models/auto_1  --do_lower_case  --do_train --train_rl --alpha 1 --guide_rate 0.5

where alpha is the harmonic coefficient for the length punishment and guide_rate is the proportion of imitation learning steps. model_type can be set as autobert/distilautobert for applying token reduction to BERT/DistilBERT.

Compute Logits for Knowledge Distilation

Use flag --do_eval_logits.

python3 run_classification.py  --task_name imdb  --model_type bert  --model_name_or_path imdb_models/bert_base --data_dir imdb --max_seq_length 512  --per_gpu_train_batch_size 8  --per_gpu_eval_batch_size 8  --output_dir imdb_models/bert_base  --do_lower_case  --do_eval_logits

This step doesn't supoort data DataParallel or DistributedDataParallel currently and should be done in a single GPU.

Train the whole network with both the task-specifc objective and RL objective

Start from the checkpoint from --train_rl model and run with flag --do_train --train_both --train_teacher:

python3 run_classification.py  --task_name imdb  --model_type autobert  --model_name_or_path imdb_models/auto_1 --data_dir imdb --max_seq_length 512  --per_gpu_train_batch_size 8  --per_gpu_eval_batch_size 1 --gradient_accumulation_steps 4 --learning_rate 3e-5 --save_steps 2000  --num_train_epochs 3  --output_dir imdb_models/auto_1_both  --do_lower_case  --do_train --train_both --train_teacher --alpha 1

Evaluate

Use flag --do_eval:

python3 run_classification.py  --task_name imdb  --model_type autobert  --model_name_or_path imdb_models/auto_1_both  --data_dir imdb --max_seq_length 512  --per_gpu_train_batch_size 8  --per_gpu_eval_batch_size 1  --output_dir imdb_models/auto_1_both  --do_lower_case  --do_eval --eval_all_checkpoints

When the batch size is more than 1 in evaluating, we will remain the same number of tokens for each instance in the same batch.

Initialize

For IMDB dataset, we find that when we directly initialize the selector with heuristic objective before train the policy network solely, we can get a bit better performance. For other datasets, this step makes little change. Run this step with flag --do_train --train_init:

python3 trans_imdb_rank.py
python3 run_classification.py  --task_name imdb  --model_type initbert  --model_name_or_path imdb_models/bert_base --data_dir imdb --max_seq_length 512  --per_gpu_train_batch_size 8  --per_gpu_eval_batch_size 8 --gradient_accumulation_steps 4 --learning_rate 3e-5 --save_steps 2000  --num_train_epochs 3  --output_dir imdb_models/bert_init  --do_lower_case  --do_train --train_init 

Question Answering

Download the SQuAD 2.0 dataset.

Download the MRQA dataset with our split] from Google Drive/Tsinghua Cloud.

Download the HotpotQA dataset from the Transformer-XH repository, where paragraphs are retrieved for each question according to TF-IDF, entity linking and hyperlink and re-ranked by BERT re-ranker.

Download the TriviaQA dataset, where paragraphs are re-rank by the linear passage re-ranker in DocQA.

Download the WikiHop dataset.

The whole training progress of question answer models is similiar to text classfication models, with flags --do_train, --do_train --train_rl, --do_train --train_both --train_teacher in turn. The codes of each dataset:

SQuAD: run_squad.py with flag version_2_with_negative

NewsQA / NaturalQA: run_mrqa.py

RACE: run_race_classify.py

HotpotQA: run_hotpotqa.py

TriviaQA: run_triviaqa.py

WikiHop: run_wikihop.py

Harmonic Coefficient Lambda

The example harmonic coefficients are shown as follows:

Dataset train_rl train_both
SQuAD 2.0 5 5
NewsQA 3 5
NaturalQA 2 2
RACE 0.5 0.1
YELP.F 2 0.5
20News 1 1
IMDB 1 1
HotpotQA 0.1 4
TriviaQA 0.5 1
Hyperparisan 0.01 0.01

Cite

If you use the code, please cite this paper:

@inproceedings{ye2021trbert,
  title={TR-BERT: Dynamic Token Reduction for Accelerating BERT Inference},
  author={Deming Ye, Yankai Lin, Yufei Huang, Maosong Sun},
  booktitle={Proceedings of NAACL 2021},
  year={2021}
}
Owner
THUNLP
Natural Language Processing Lab at Tsinghua University
THUNLP
A minimal implementation of Gaussian process regression in PyTorch

pytorch-minimal-gaussian-process In search of truth, simplicity is needed. There exist heavy-weighted libraries, but as you know, we need to go bare b

Sangwoong Yoon 38 Nov 25, 2022
Official implementation for paper: Feature-Style Encoder for Style-Based GAN Inversion

Feature-Style Encoder for Style-Based GAN Inversion Official implementation for paper: Feature-Style Encoder for Style-Based GAN Inversion. Code will

InterDigital 63 Jan 03, 2023
Code for Robust Contrastive Learning against Noisy Views

Robust Contrastive Learning against Noisy Views This repository provides a PyTorch implementation of the Robust InfoNCE loss proposed in paper Robust

Ching-Yao Chuang 53 Jan 08, 2023
[AAAI22] Reliable Propagation-Correction Modulation for Video Object Segmentation

Reliable Propagation-Correction Modulation for Video Object Segmentation (AAAI22) Preview version paper of this work is available at: https://arxiv.or

Xiaohao Xu 70 Dec 04, 2022
Matplotlib Image labeller for classifying images

mpl-image-labeller Use Matplotlib to label images for classification. Works anywhere Matplotlib does - from the notebook to a standalone gui! For more

Ian Hunt-Isaak 5 Sep 24, 2022
A Fast Monotone Rotating Shallow Water model

pyRSW A Fast Monotone Rotating Shallow Water model How fast? As fast as a sustained 2 Gflop/s per core on a 2.5 GHz cpu (or 2048 Gflop/s with 1024 cor

Guillaume Roullet 13 Sep 28, 2022
This repository will be a summary and outlook on all our open, medical, AI advancements.

medical by LAION This repository will be a summary and outlook on all our open, medical, AI advancements. See the medical-general channel in the medic

LAION AI 18 Dec 30, 2022
FFTNet vocoder implementation

Unofficial Implementation of FFTNet vocode paper. implement the model. implement tests. overfit on a single batch (sanity check). linearize weights fo

Eren Gölge 81 Dec 08, 2022
[ICCV 2021] Code release for "Sub-bit Neural Networks: Learning to Compress and Accelerate Binary Neural Networks"

Sub-bit Neural Networks: Learning to Compress and Accelerate Binary Neural Networks By Yikai Wang, Yi Yang, Fuchun Sun, Anbang Yao. This is the pytorc

Yikai Wang 26 Nov 20, 2022
Implementation of PersonaGPT Dialog Model

PersonaGPT An open-domain conversational agent with many personalities PersonaGPT is an open-domain conversational agent cpable of decoding personaliz

ILLIDAN Lab 42 Jan 01, 2023
A time series processing library

Timeseria Timeseria is a time series processing library which aims at making it easy to handle time series data and to build statistical and machine l

Stefano Alberto Russo 11 Aug 08, 2022
Implementation of momentum^2 teacher

Momentum^2 Teacher: Momentum Teacher with Momentum Statistics for Self-Supervised Learning Requirements All experiments are done with python3.6, torch

jemmy li 121 Sep 26, 2022
Finetuning Pipeline

KLUE Baseline Korean(한국어) KLUE-baseline contains the baseline code for the Korean Language Understanding Evaluation (KLUE) benchmark. See our paper fo

74 Dec 13, 2022
Official implementation of EfficientPose

EfficientPose This is the official implementation of EfficientPose. We based our work on the Keras EfficientDet implementation xuannianz/EfficientDet

2 May 17, 2022
(CVPR 2022 Oral) Official implementation for "Surface Representation for Point Clouds"

RepSurf - Surface Representation for Point Clouds [CVPR 2022 Oral] By Haoxi Ran* , Jun Liu, Chengjie Wang ( * : corresponding contact) The pytorch off

Haoxi Ran 264 Dec 23, 2022
9th place solution in "Santa 2020 - The Candy Cane Contest"

Santa 2020 - The Candy Cane Contest My solution in this Kaggle competition "Santa 2020 - The Candy Cane Contest", 9th place. Basic Strategy In this co

toshi_k 22 Nov 26, 2021
git《Tangent Space Backpropogation for 3D Transformation Groups》(CVPR 2021) GitHub:1]

LieTorch: Tangent Space Backpropagation Introduction The LieTorch library generalizes PyTorch to 3D transformation groups. Just as torch.Tensor is a m

Princeton Vision & Learning Lab 482 Jan 06, 2023
Code for "Single-view robot pose and joint angle estimation via render & compare", CVPR 2021 (Oral).

Single-view robot pose and joint angle estimation via render & compare Yann Labbé, Justin Carpentier, Mathieu Aubry, Josef Sivic CVPR: Conference on C

Yann Labbé 51 Oct 14, 2022
The official implementation of the Interspeech 2021 paper WSRGlow: A Glow-based Waveform Generative Model for Audio Super-Resolution.

WSRGlow The official implementation of the Interspeech 2021 paper WSRGlow: A Glow-based Waveform Generative Model for Audio Super-Resolution. Audio sa

Kexun Zhang 96 Jan 03, 2023
SCAAML is a deep learning framwork dedicated to side-channel attacks run on top of TensorFlow 2.x.

SCAAML (Side Channel Attacks Assisted with Machine Learning) is a deep learning framwork dedicated to side-channel attacks. It is written in python and run on top of TensorFlow 2.x.

Google 69 Dec 21, 2022