The Lottery Ticket Hypothesis for Pre-trained BERT Networks

License: MIT

Code for the paper "The Lottery Ticket Hypothesis for Pre-trained BERT Networks" (NeurIPS 2020).

Tianlong Chen, Jonathan Frankle, Shiyu Chang, Sijia Liu, Yang Zhang, Zhangyang Wang, Michael Carbin.

Our implementation is based on the Hugging Face repo; see the README there for details. Pre-trained subnetworks are coming soon.

Overview

The Existence of Matching Subnetworks in BERT

Transfer Learning for BERT Winning Tickets

Method

Reproduction Details

Prerequisites and Installation

See the Hugging Face README for prerequisites and installation details.

Iterative Magnitude Pruning (IMP)

MLM task:

python -u LT_pretrain.py \
    --output_dir LT_pretrain_model \
    --model_type bert \
    --model_name_or_path bert-base-uncased \
    --train_data_file pretrain_data/en.train \
    --do_train \
    --eval_data_file pretrain_data/en.valid \
    --do_eval \
    --per_gpu_train_batch_size 16 \
    --per_gpu_eval_batch_size 16 \
    --evaluate_during_training \
    --num_train_epochs 1 \
    --logging_steps 10000 \
    --save_steps 10000 \
    --mlm \
    --overwrite_output_dir \
    --seed 57

Glue task:

python -u LT_glue.py \
    --output_dir tmp/mnli \
    --logging_steps 36813 \
    --task_name MNLI \
    --data_dir glue_data/MNLI \
    --model_type bert \
    --model_name_or_path bert-base-uncased \
    --do_train \
    --do_eval \
    --do_lower_case \
    --max_seq_length 128 \
    --per_gpu_train_batch_size 32 \
    --learning_rate 2e-5 \
    --num_train_epochs 30 \
    --overwrite_output_dir \
    --evaluate_during_training \
    --save_steps 36813 \
    --eval_all_checkpoints \
    --seed 57

SQuAD task:

python -u squad_trans.py \
    --output_dir tmp/530/squad \
    --model_type bert \
    --model_name_or_path bert-base-uncased \
    --do_train \
    --do_eval \
    --do_lower_case \
    --train_file SQuAD/train-v1.1.json \
    --predict_file SQuAD/dev-v1.1.json \
    --per_gpu_train_batch_size 16 \
    --learning_rate 3e-5 \
    --num_train_epochs 40 \
    --max_seq_length 384 \
    --doc_stride 128 \
    --evaluate_during_training \
    --eval_all_checkpoints \
    --overwrite_output_dir \
    --logging_steps 22000 \
    --save_steps 22000 \
    --seed 57
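
The three commands above run iterative magnitude pruning on real BERT checkpoints. As a rough, framework-free illustration of the train, prune, rewind loop that IMP refers to, the sketch below works on a flat list of floats. Everything in it is hypothetical and not taken from the LT_*.py scripts: the function names, the toy_train stand-in for a training run, and the per-round pruning fraction (passed explicitly because this README does not state one). It also assumes the usual lottery-ticket convention of resetting surviving weights to their initial values before each round.

```python
from typing import Callable, List


def prune_smallest(weights: List[float], mask: List[int], frac: float) -> List[int]:
    """Zero out the `frac` fraction of still-active weights with smallest |w|."""
    active = sorted((abs(w), i) for i, (w, m) in enumerate(zip(weights, mask)) if m)
    n_prune = int(frac * len(active))
    new_mask = list(mask)
    for _, i in active[:n_prune]:
        new_mask[i] = 0
    return new_mask


def iterative_magnitude_pruning(
    init_weights: List[float],
    train: Callable[[List[float], List[int]], List[float]],
    rounds: int,
    frac: float,
) -> List[int]:
    """Alternate train -> prune -> rewind and return the final 0/1 mask."""
    mask = [1] * len(init_weights)
    for _ in range(rounds):
        # Rewind (assumed convention): restart from the initial weights, masked.
        start = [w * m for w, m in zip(init_weights, mask)]
        trained = train(start, mask)
        mask = prune_smallest(trained, mask, frac)
    return mask


def toy_train(weights: List[float], mask: List[int]) -> List[float]:
    # Hypothetical stand-in for a real training run: rescale surviving weights.
    return [1.5 * w * m for w, m in zip(weights, mask)]


theta0 = [0.5, -0.1, 0.8, 0.05, -0.9, 0.3, -0.2, 0.7, 0.4, -0.6]
final_mask = iterative_magnitude_pruning(theta0, toy_train, rounds=3, frac=0.2)
print(final_mask, "remaining:", sum(final_mask), "of", len(final_mask))
```

Because each round prunes a fraction of the weights that are still active, the remaining density after n rounds is roughly (1 - frac)^n, which is why many rounds are needed to reach high sparsity.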

One-shot Magnitude Pruning (OMP)

python oneshot.py --weight [pre or rand] --model [glue or squad or pretrain] --rate 0.5
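
For intuition about what a pruning rate of 0.5 means, here is a minimal sketch of one-shot magnitude pruning on plain Python lists. It is not the code in oneshot.py: the function name, the toy tensors, and the choice of a single global threshold across all tensors (rather than one threshold per tensor) are assumptions made for illustration.

```python
from typing import Dict, List


def one_shot_magnitude_masks(
    weights: Dict[str, List[float]], rate: float
) -> Dict[str, List[int]]:
    """Return 0/1 masks dropping the `rate` fraction of smallest-|w| entries.

    The threshold is global across all tensors (an assumption). Ties at the
    threshold are all pruned, so sparsity can slightly exceed `rate`.
    """
    if not 0.0 <= rate < 1.0:
        raise ValueError("rate must be in [0, 1)")
    magnitudes = sorted(abs(w) for values in weights.values() for w in values)
    n_prune = int(rate * len(magnitudes))
    if n_prune == 0:
        return {name: [1] * len(values) for name, values in weights.items()}
    threshold = magnitudes[n_prune - 1]  # largest magnitude that gets pruned
    return {
        name: [0 if abs(w) <= threshold else 1 for w in values]
        for name, values in weights.items()
    }


# Hypothetical toy "model" with two small weight tensors.
toy = {
    "layer0.weight": [0.9, -0.05, 0.3, -0.7],
    "layer1.weight": [0.01, -0.4, 0.2, 0.6],
}
masks = one_shot_magnitude_masks(toy, rate=0.5)
kept = sum(sum(m) for m in masks.values())
total = sum(len(m) for m in masks.values())
print(masks, "sparsity:", 1 - kept / total)
```

Unlike IMP, this computes the mask in a single pass with no intermediate training, which makes it much cheaper.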

Fine-tuning

MLM task:

# --dir: use random weights or the official pre-trained weights
# --weight_pertub: weights for BERT (not required)
# --mask_dir: mask file
python -u pretrain_trans.py \
    --dir pre \
    --weight_pertub tmp/shuffle_weight.pt \
    --mask_dir tmp/dif_mask/pretrain_mask.pt \
    --output_dir tmp/530/pre \
    --model_type bert \
    --model_name_or_path bert-base-uncased \
    --train_data_file pretrain_data/en.train \
    --do_train \
    --eval_data_file pretrain_data/en.valid \
    --do_eval \
    --per_gpu_train_batch_size 8 \
    --per_gpu_eval_batch_size 8 \
    --evaluate_during_training \
    --num_train_epochs 1 \
    --logging_steps 2000 \
    --save_steps 0 \
    --max_steps 20000 \
    --mlm \
    --overwrite_output_dir \
    --seed 57

Glue task:

# --dir: use random weights or the official pre-trained weights
# --weight_pertub: weights for BERT (not required)
# --mask_dir: mask file
python -u glue_trans.py \
    --dir pre \
    --weight_pertub tmp/shuffle_weight.pt \
    --mask_dir tmp/dif_mask/mnli_mask.pt \
    --output_dir tmp/530/mnli \
    --logging_steps 12271 \
    --task_name MNLI \
    --data_dir glue_data/MNLI \
    --model_type bert \
    --model_name_or_path bert-base-uncased \
    --do_train \
    --do_eval \
    --do_lower_case \
    --max_seq_length 128 \
    --per_gpu_train_batch_size 32 \
    --learning_rate 2e-5 \
    --num_train_epochs 3 \
    --overwrite_output_dir \
    --evaluate_during_training \
    --save_steps 0 \
    --eval_all_checkpoints \
    --seed 5

SQuAD task:

# --dir: use random weights or the official pre-trained weights
# --weight_pertub: weights for BERT (not required)
# --mask_dir: mask file
python -u squad_trans.py \
    --dir pre \
    --weight_pertub tmp/shuffle_weight.pt \
    --mask_dir tmp/dif_mask/squad_mask.pt \
    --output_dir tmp/530/squad \
    --model_type bert \
    --model_name_or_path bert-base-uncased \
    --do_train \
    --do_eval \
    --do_lower_case \
    --train_file SQuAD/train-v1.1.json \
    --predict_file SQuAD/dev-v1.1.json \
    --per_gpu_train_batch_size 16 \
    --learning_rate 3e-5 \
    --num_train_epochs 4 \
    --max_seq_length 384 \
    --doc_stride 128 \
    --evaluate_during_training \
    --eval_all_checkpoints \
    --overwrite_output_dir \
    --logging_steps 5500 \
    --save_steps 0 \
    --seed 57
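
The fine-tuning commands above each take a mask file via --mask_dir. One way to think about what a mask does during fine-tuning is the sketch below: the mask is multiplied into the weights once, then re-applied after every update so pruned positions stay at zero. This is an illustration on plain lists, not the logic of the *_trans.py scripts; apply_mask and masked_sgd_step are hypothetical names, and plain SGD is assumed only to keep the example short.

```python
from typing import List


def apply_mask(weights: List[float], mask: List[int]) -> List[float]:
    """Element-wise product of mask and weights: pruned positions become 0."""
    return [w * m for w, m in zip(weights, mask)]


def masked_sgd_step(
    weights: List[float], grads: List[float], mask: List[int], lr: float
) -> List[float]:
    """One SGD update followed by re-masking, so pruned weights stay at zero."""
    updated = [w - lr * g for w, g in zip(weights, grads)]
    return apply_mask(updated, mask)


theta = [0.5, -0.25, 1.0, 0.75]
mask = [1, 0, 1, 0]  # hypothetical mask keeping positions 0 and 2

theta = apply_mask(theta, mask)
theta = masked_sgd_step(theta, grads=[1.0, 1.0, -1.0, 1.0], mask=mask, lr=0.5)
print(theta)
```

Re-masking after the update matters because the gradient at a pruned position is generally non-zero, so an unmasked step would silently reintroduce pruned weights.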

Subnetwork with Randomly Shuffled Pre-trained Weights

python pertub_weight.py
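
As an illustration of what "randomly shuffled" weights could look like, the sketch below permutes the values inside each named tensor with a fixed seed, so every tensor keeps exactly the same set of values but loses their arrangement. Whether pertub_weight.py shuffles per tensor in this way is an assumption; the function name and toy data are hypothetical.

```python
import random
from typing import Dict, List


def shuffle_weights(
    weights: Dict[str, List[float]], seed: int
) -> Dict[str, List[float]]:
    """Return a copy in which each tensor's values are randomly permuted.

    Shuffling is done independently per tensor (an assumption), so the
    value distribution of every tensor is preserved exactly.
    """
    rng = random.Random(seed)  # local RNG: reproducible, no global side effects
    shuffled = {}
    for name, values in weights.items():
        permuted = list(values)  # copy so the input is left untouched
        rng.shuffle(permuted)
        shuffled[name] = permuted
    return shuffled


pretrained = {
    "layer0.weight": [0.9, -0.05, 0.3, -0.7],
    "layer1.weight": [0.01, -0.4, 0.2, 0.6],
}
perturbed = shuffle_weights(pretrained, seed=57)
print(perturbed)
```

A baseline like this separates the effect of the pre-trained weight values themselves from the effect of their learned arrangement.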

Citation

If you use this code for your research, please cite our paper:

@misc{chen2020lottery,
    title={The Lottery Ticket Hypothesis for Pre-trained BERT Networks},
    author={Tianlong Chen and Jonathan Frankle and Shiyu Chang and Sijia Liu and Yang Zhang and Zhangyang Wang and Michael Carbin},
    year={2020},
    eprint={2007.12223},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}

Acknowledgement

We would like to express our deepest gratitude to the MIT-IBM Watson AI Lab. In particular, we would like to thank John Cohn for his generous help in providing us with the computing resources necessary to conduct this research.

Owner: VITA, Visual Informatics Group @ University of Texas at Austin