The Lottery Ticket Hypothesis for Pre-trained BERT Networks

License: MIT

Code for the paper "The Lottery Ticket Hypothesis for Pre-trained BERT Networks" [NeurIPS 2020].

Tianlong Chen, Jonathan Frankle, Shiyu Chang, Sijia Liu, Yang Zhang, Zhangyang Wang, Michael Carbin.

Our implementation is based on the Huggingface Transformers repo; see its README for details. Pre-trained subnetworks are coming soon.

Overview

The Existence of Matching Subnetworks in BERT

Transfer Learning for BERT Winning Tickets

Method

Reproduction Details

Prerequisites and Installation

See the Huggingface README for details.

Iterative Magnitude Pruning (IMP)

MLM task:

python -u LT_pretrain.py \
    --output_dir LT_pretrain_model \
    --model_type bert \
    --model_name_or_path bert-base-uncased \
    --train_data_file pretrain_data/en.train \
    --do_train \
    --eval_data_file pretrain_data/en.valid \
    --do_eval \
    --per_gpu_train_batch_size 16 \
    --per_gpu_eval_batch_size 16 \
    --evaluate_during_training \
    --num_train_epochs 1 \
    --logging_steps 10000 \
    --save_steps 10000 \
    --mlm \
    --overwrite_output_dir \
    --seed 57
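As a rough illustration of what IMP does, the loop below trains, removes the smallest-magnitude surviving weights (ranked globally across parameters), and rewinds the survivors to their starting pre-trained values before the next round. This is a framework-free sketch under stated assumptions, not the code in LT_pretrain.py: weights are plain Python lists keyed by hypothetical parameter names, `train` is a caller-supplied stand-in for a full MLM or fine-tuning run, and the per-round pruning fraction (10% by default here) is an assumption.

```python
from typing import Callable, Dict, List

Weights = Dict[str, List[float]]
Masks = Dict[str, List[int]]


def prune_smallest(weights: Weights, masks: Masks, fraction: float) -> Masks:
    """Zero out `fraction` of the still-unpruned weights, ranked globally by |w|."""
    candidates = sorted(
        (abs(w), name, i)
        for name, ws in weights.items()
        for i, (w, m) in enumerate(zip(ws, masks[name]))
        if m
    )
    n_prune = int(len(candidates) * fraction)
    new_masks = {name: list(m) for name, m in masks.items()}
    for _, name, i in candidates[:n_prune]:
        new_masks[name][i] = 0
    return new_masks


def iterative_magnitude_pruning(
    theta0: Weights,
    train: Callable[[Weights, Masks], Weights],  # hypothetical training callback
    rounds: int,
    fraction: float = 0.1,  # assumed per-round pruning fraction
) -> Masks:
    masks = {name: [1] * len(ws) for name, ws in theta0.items()}
    for _ in range(rounds):
        # Rewind: each round restarts from theta0 with pruned weights zeroed.
        start = {
            name: [w * m for w, m in zip(theta0[name], masks[name])]
            for name in theta0
        }
        trained = train(start, masks)  # the callback should keep pruned weights at 0
        masks = prune_smallest(trained, masks, fraction)
    return masks


# Identity "training" keeps this toy example deterministic.
theta0 = {"layer.0": [0.5, -0.1, 0.3, 0.05], "layer.1": [-0.9, 0.2, 0.01, 0.4, 0.7, -0.6]}
masks = iterative_magnitude_pruning(theta0, train=lambda w, m: w, rounds=2, fraction=0.2)
print(masks)
```

With 10 weights and a 20% fraction, round one removes 0.01 and 0.05, and round two removes -0.1 (the smallest of the 8 survivors). Rewinding to the pre-trained values rather than re-initializing is what makes the result a subnetwork of pre-trained BERT.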

GLUE task:

python -u LT_glue.py \
    --output_dir tmp/mnli \
    --logging_steps 36813 \
    --task_name MNLI \
    --data_dir glue_data/MNLI \
    --model_type bert \
    --model_name_or_path bert-base-uncased \
    --do_train \
    --do_eval \
    --do_lower_case \
    --max_seq_length 128 \
    --per_gpu_train_batch_size 32 \
    --learning_rate 2e-5 \
    --num_train_epochs 30 \
    --overwrite_output_dir \
    --evaluate_during_training \
    --save_steps 36813 \
    --eval_all_checkpoints \
    --seed 57
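The step counts in this command appear to be tied to epochs. MNLI has 392,702 training examples, so at a batch size of 32 on one GPU an epoch is about 12,272 steps (392,702 / 32 ≈ 12,271.9), and 36813 is exactly 3 × 12,271, i.e. roughly three epochs. The helpers below (written for this illustration, not part of the repository) show that a 30-epoch run reaches the 36813-step interval 10 times. Whether LT_glue.py prunes at each of those points is an inference from the matching values, not something this README states.

```python
import math


def steps_per_epoch(num_examples: int, batch_size: int, n_gpu: int = 1) -> int:
    # A DataLoader without drop_last yields a final partial batch, hence ceil.
    return math.ceil(num_examples / (batch_size * n_gpu))


def interval_count(num_examples: int, batch_size: int, epochs: int, interval: int) -> int:
    """How many times a run of `epochs` epochs reaches a multiple of `interval` steps."""
    total_steps = steps_per_epoch(num_examples, batch_size) * epochs
    return total_steps // interval


MNLI_TRAIN = 392_702  # size of the MNLI training split

print(steps_per_epoch(MNLI_TRAIN, 32))
print(36813 / steps_per_epoch(MNLI_TRAIN, 32))  # epochs per save interval, close to 3
print(interval_count(MNLI_TRAIN, 32, 30, 36813))
```

The same arithmetic applied to the GLUE fine-tuning command (3 epochs, `--logging_steps 12271`) gives one evaluation per epoch.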

SQuAD task:

python -u squad_trans.py \
    --output_dir tmp/530/squad \
    --model_type bert \
    --model_name_or_path bert-base-uncased \
    --do_train \
    --do_eval \
    --do_lower_case \
    --train_file SQuAD/train-v1.1.json \
    --predict_file SQuAD/dev-v1.1.json \
    --per_gpu_train_batch_size 16 \
    --learning_rate 3e-5 \
    --num_train_epochs 40 \
    --max_seq_length 384 \
    --doc_stride 128 \
    --evaluate_during_training \
    --eval_all_checkpoints \
    --overwrite_output_dir \
    --logging_steps 22000 \
    --save_steps 22000 \
    --seed 57
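Because each IMP round removes a fixed fraction p of the weights that are still present, sparsity grows geometrically: after n rounds, a share (1 - p)^n of the prunable weights survives, so sparsity is 1 - (1 - p)^n. The snippet below evaluates that formula and finds how many rounds a target sparsity needs. The 10% default for p is an assumption for illustration.

```python
def sparsity_after(rounds: int, fraction: float = 0.1) -> float:
    """Share of prunable weights removed after `rounds` IMP rounds."""
    return 1.0 - (1.0 - fraction) ** rounds


def rounds_to_reach(target: float, fraction: float = 0.1) -> int:
    """Smallest number of rounds whose sparsity is at least `target`."""
    if not 0.0 <= target < 1.0 or not 0.0 < fraction <= 1.0:
        raise ValueError("need 0 <= target < 1 and 0 < fraction <= 1")
    rounds = 0
    while sparsity_after(rounds, fraction) < target:
        rounds += 1
    return rounds


for n in (1, 5, 10):
    print(n, round(sparsity_after(n), 4))
print(rounds_to_reach(0.5), rounds_to_reach(0.9))
```

At p = 0.1, ten rounds give about 65.1% sparsity, 50% takes 7 rounds, and 90% takes 22 rounds, which is why reaching high sparsity with IMP costs many full training runs.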

One-shot Magnitude Pruning (OMP)

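python oneshot.py --weight pre --model glue --rate 0.5

Here --weight is pre (official pre-trained weights) or rand (random weights), --model is glue, squad, or pretrain, and --rate is the pruning rate (0.5 in this example).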
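One-shot magnitude pruning skips the train-prune-rewind loop and removes the requested fraction of weights in a single pass. A minimal sketch of that idea follows, assuming a global ranking by absolute value across all prunable parameters; `--rate 0.5` would correspond to `rate=0.5`. The function name, the parameter names, and the list-based weight format are illustrative, not oneshot.py's actual interface. With `--weight rand`, the same ranking would be applied to randomly initialized weights instead of the pre-trained ones.

```python
from typing import Dict, List


def one_shot_mask(weights: Dict[str, List[float]], rate: float) -> Dict[str, List[int]]:
    """Prune `rate` of all weights at once, keeping the largest magnitudes."""
    ranked = sorted(
        (abs(w), name, i) for name, ws in weights.items() for i, w in enumerate(ws)
    )
    n_prune = int(len(ranked) * rate)
    masks = {name: [1] * len(ws) for name, ws in weights.items()}
    for _, name, i in ranked[:n_prune]:
        masks[name][i] = 0
    return masks


weights = {"attn.q": [0.3, -0.05, 0.8, 0.1], "ffn.out": [-0.2, 0.6]}  # hypothetical names
masks = one_shot_mask(weights, rate=0.5)
print(masks)
```

Because the ranking is global, layers with many small weights lose more than half of their entries while others lose fewer; a per-layer variant would instead apply the rate to each parameter separately.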

Fine-tuning

The fine-tuning scripts take three extra options: --dir selects official pre-trained weights (pre) or random weights; --weight_pertub optionally loads a replacement set of BERT weights, such as tmp/shuffle_weight.pt; and --mask_dir is the mask file that defines the subnetwork.

MLM task:

python -u pretrain_trans.py \
    --dir pre \
    --weight_pertub tmp/shuffle_weight.pt \
    --mask_dir tmp/dif_mask/pretrain_mask.pt \
    --output_dir tmp/530/pre \
    --model_type bert \
    --model_name_or_path bert-base-uncased \
    --train_data_file pretrain_data/en.train \
    --do_train \
    --eval_data_file pretrain_data/en.valid \
    --do_eval \
    --per_gpu_train_batch_size 8 \
    --per_gpu_eval_batch_size 8 \
    --evaluate_during_training \
    --num_train_epochs 1 \
    --logging_steps 2000 \
    --save_steps 0 \
    --max_steps 20000 \
    --mlm \
    --overwrite_output_dir \
    --seed 57
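Fine-tuning a subnetwork means training only the weights the mask keeps. One common way to do this, sketched below as an assumption about the general technique rather than a description of pretrain_trans.py, is to zero the masked weights before training and force them back to zero after every optimizer step. Plain SGD on Python lists stands in for the real optimizer and tensors; `apply_mask` and `masked_sgd_step` are hypothetical helpers.

```python
from typing import List


def apply_mask(weights: List[float], mask: List[int]) -> List[float]:
    """Zero every weight whose mask entry is 0."""
    return [w if m else 0.0 for w, m in zip(weights, mask)]


def masked_sgd_step(
    weights: List[float], grads: List[float], mask: List[int], lr: float
) -> List[float]:
    """One SGD update that leaves pruned positions at exactly zero."""
    return [w - lr * g if m else 0.0 for w, g, m in zip(weights, grads, mask)]


mask = [1, 0, 1]
w = apply_mask([0.5, 0.7, -0.2], mask)  # pruned weight 0.7 becomes 0.0
w = masked_sgd_step(w, grads=[0.1, 0.3, -0.1], mask=mask, lr=0.1)
print(w)
```

Re-applying the mask after the update matters with optimizers such as Adam, whose momentum terms can otherwise move a pruned weight away from zero even when its current gradient is masked.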

GLUE task:

python -u glue_trans.py \
    --dir pre \
    --weight_pertub tmp/shuffle_weight.pt \
    --mask_dir tmp/dif_mask/mnli_mask.pt \
    --output_dir tmp/530/mnli \
    --logging_steps 12271 \
    --task_name MNLI \
    --data_dir glue_data/MNLI \
    --model_type bert \
    --model_name_or_path bert-base-uncased \
    --do_train \
    --do_eval \
    --do_lower_case \
    --max_seq_length 128 \
    --per_gpu_train_batch_size 32 \
    --learning_rate 2e-5 \
    --num_train_epochs 3 \
    --overwrite_output_dir \
    --evaluate_during_training \
    --save_steps 0 \
    --eval_all_checkpoints \
    --seed 5
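Before a long GLUE run it can be worth confirming that the mask file describes the sparsity you expect. The sketch below reports the pruned fraction per parameter and overall. It assumes the mask has already been loaded (for a .pt file, `torch.load` would be the usual route) into a mapping from parameter name to 0/1 values; the actual layout of mnli_mask.pt is not documented here, so the names and format below are hypothetical.

```python
from typing import Dict, List, Tuple


def sparsity_report(masks: Dict[str, List[int]]) -> Tuple[Dict[str, float], float]:
    """Return (pruned fraction per parameter, overall pruned fraction)."""
    per_param = {name: m.count(0) / len(m) for name, m in masks.items()}
    total = sum(len(m) for m in masks.values())
    pruned = sum(m.count(0) for m in masks.values())
    return per_param, pruned / total


masks = {  # hypothetical parameter names and tiny masks
    "encoder.layer.0.attention.self.query.weight": [1, 0, 0, 1],
    "encoder.layer.0.output.dense.weight": [0, 1, 1, 1, 1, 1],
}
per_param, overall = sparsity_report(masks)
for name, frac in per_param.items():
    print(f"{name}: {frac:.2%}")
print(f"overall: {overall:.2%}")
```

The overall figure weights each parameter by its size, so it can differ noticeably from the average of the per-parameter fractions.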

SQuAD task:

python -u squad_trans.py \
    --dir pre \
    --weight_pertub tmp/shuffle_weight.pt \
    --mask_dir tmp/dif_mask/squad_mask.pt \
    --output_dir tmp/530/squad \
    --model_type bert \
    --model_name_or_path bert-base-uncased \
    --do_train \
    --do_eval \
    --do_lower_case \
    --train_file SQuAD/train-v1.1.json \
    --predict_file SQuAD/dev-v1.1.json \
    --per_gpu_train_batch_size 16 \
    --learning_rate 3e-5 \
    --num_train_epochs 4 \
    --max_seq_length 384 \
    --doc_stride 128 \
    --evaluate_during_training \
    --eval_all_checkpoints \
    --overwrite_output_dir \
    --logging_steps 5500 \
    --save_steps 0 \
    --seed 57
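Both SQuAD commands set `--max_seq_length 384` and `--doc_stride 128`. Contexts that do not fit in one window are split into overlapping chunks, each starting `doc_stride` tokens after the previous one. The function below reproduces that sliding-window logic as found in the original BERT SQuAD preprocessing, reserving three special tokens ([CLS] and two [SEP]); the exact code path in squad_trans.py may differ, and the token counts are assumed to be computed already.

```python
from typing import List, Tuple


def doc_spans(
    doc_len: int, query_len: int, max_seq_length: int = 384, doc_stride: int = 128
) -> List[Tuple[int, int]]:
    """Return (start, length) windows over a tokenized context."""
    # [CLS] query [SEP] context [SEP] -> 3 special tokens
    max_tokens_for_doc = max_seq_length - query_len - 3
    if max_tokens_for_doc <= 0:
        raise ValueError("query too long for max_seq_length")
    spans = []
    start = 0
    while start < doc_len:
        length = min(doc_len - start, max_tokens_for_doc)
        spans.append((start, length))
        if start + length == doc_len:
            break
        start += min(length, doc_stride)
    return spans


print(doc_spans(600, 20))  # a 600-token context with a 20-token question
```

Each window becomes a separate training feature, so the number of optimizer steps per epoch depends on the number of windows rather than on the number of questions.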

Subnetwork with Randomly Shuffled Pre-trained Weights

python pertub_weight.py
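This README does not describe what pertub_weight.py does beyond the section title. One plausible reading, sketched below purely as an assumption, is a per-layer shuffle: each parameter keeps exactly the same set of pre-trained values but in a random order, so the layer-wise weight distribution is preserved while the specific arrangement learned during pre-training is lost. The function name, parameter names, and list-based format are hypothetical.

```python
import random
from typing import Dict, List


def shuffle_within_layers(
    weights: Dict[str, List[float]], seed: int = 0
) -> Dict[str, List[float]]:
    """Permute the values inside each parameter independently, with a fixed seed."""
    rng = random.Random(seed)
    shuffled = {}
    for name, ws in weights.items():
        values = list(ws)  # copy so the original weights are untouched
        rng.shuffle(values)
        shuffled[name] = values
    return shuffled


pretrained = {"attn.q": [0.3, -0.05, 0.8, 0.1], "ffn.out": [-0.2, 0.6, 0.4]}
shuffled = shuffle_within_layers(pretrained, seed=57)
print(shuffled)
```

Shuffling within each parameter, rather than across the whole model, keeps per-layer statistics such as the weight norm identical, which isolates the effect of where each pre-trained value sits.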

Citation

If you use this code for your research, please cite our paper:

@misc{chen2020lottery,
    title={The Lottery Ticket Hypothesis for Pre-trained BERT Networks},
    author={Tianlong Chen and Jonathan Frankle and Shiyu Chang and Sijia Liu and Yang Zhang and Zhangyang Wang and Michael Carbin},
    year={2020},
    eprint={2007.12223},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}

Acknowledgement

We would like to express our deepest gratitude to the MIT-IBM Watson AI Lab. In particular, we would like to thank John Cohn for his generous help in providing us with the computing resources necessary to conduct this research.

Owner
VITA
Visual Informatics Group @ University of Texas at Austin