XtremeDistil framework for distilling/compressing massive multilingual neural network models into tiny and efficient models for AI at scale

Overview

XtremeDistilTransformers for Distilling Massive Multilingual Neural Networks

ACL 2020 Microsoft Research [Paper] [Video]

Releasing [XtremeDistilTransformers] with TensorFlow 2.3 and HuggingFace Transformers with a unified API and the following features:

  • Distil any supported pre-trained language model as the teacher (e.g., BERT, ELECTRA, RoBERTa)
  • Initialize the student model with any pre-trained model (e.g., MiniLM, DistilBERT, TinyBERT), or initialize from scratch
  • Multilingual text classification and sequence tagging
  • Distil multiple hidden states from teacher
  • Distil deep attention networks from teacher
  • Pairwise and instance-level classification tasks (e.g., MNLI, MRPC, SST)
  • Progressive knowledge transfer with gradual unfreezing
  • Fast mixed-precision training for distillation (e.g., mixed_float16, mixed_bfloat16)
  • ONNX runtime inference

Install requirements: pip install -r requirements.txt

Initialize XtremeDistilTransformer with the [6/384 pre-trained checkpoint](https://huggingface.co/microsoft/xtremedistil-l6-h384-uncased) or the TinyBERT (4/312) pre-trained checkpoint
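
The following minimal sketch (not part of the repo scripts) shows how the 6/384 student checkpoint can be loaded and inspected with HuggingFace Transformers; pass from_pt=True to TFAutoModel.from_pretrained if only PyTorch weights are available for a checkpoint:

from transformers import AutoTokenizer, TFAutoModel

tokenizer = AutoTokenizer.from_pretrained("microsoft/xtremedistil-l6-h384-uncased")
student = TFAutoModel.from_pretrained("microsoft/xtremedistil-l6-h384-uncased")

# Encode a sample sentence and inspect the student hidden states (384-dim).
inputs = tokenizer("XtremeDistil compresses massive multilingual models.", return_tensors="tf")
outputs = student(**inputs)
last_hidden_state = outputs[0]   # shape: (1, seq_len, 384)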

Sample usages for distilling different pre-trained language models (tested with Python 3.6.9 and CUDA 10.2)

Training

Sequence Labeling for Wiki NER

PYTHONHASHSEED=42 python run_xtreme_distil.py \
--task $$PT_DATA_DIR/datasets/NER \
--model_dir $$PT_OUTPUT_DIR \
--seq_len 32 \
--transfer_file $$PT_DATA_DIR/datasets/NER/unlabeled.txt \
--do_NER \
--pt_teacher TFBertModel \
--pt_teacher_checkpoint bert-base-multilingual-cased \
--student_distil_batch_size 256 \
--student_ft_batch_size 32 \
--teacher_batch_size 128 \
--pt_student_checkpoint microsoft/xtremedistil-l6-h384-uncased \
--distil_chunk_size 10000 \
--teacher_model_dir $$PT_OUTPUT_DIR \
--distil_multi_hidden_states \
--distil_attention \
--compress_word_embedding \
--freeze_word_embedding \
--opt_policy mixed_float16

Text Classification for MNLI

PYTHONHASHSEED=42 python run_xtreme_distil.py \
--task $$PT_DATA_DIR/glue_data/MNLI \
--model_dir $$PT_OUTPUT_DIR \
--seq_len 128 \
--transfer_file $$PT_DATA_DIR/glue_data/MNLI/train.tsv \
--do_pairwise \
--pt_teacher TFElectraModel \
--pt_teacher_checkpoint google/electra-base-discriminator \
--student_distil_batch_size 128 \
--student_ft_batch_size 32 \
--pt_student_checkpoint microsoft/xtremedistil-l6-h384-uncased \
--teacher_model_dir $$PT_OUTPUT_DIR \
--teacher_batch_size 32 \
--distil_chunk_size 300000 \
--opt_policy mixed_float16

Alternatively, use the TinyBERT pre-trained student model checkpoint with --pt_student_checkpoint nreimers/TinyBERT_L-4_H-312_v2

Arguments


- task folder contains
	-- train/dev/test '.tsv' files with text and classification labels / token-wise tags (space-separated)
	--- Example 1: feel good about themselves <tab> 1
	--- Example 2: '' Atelocentra '' Meyrick , 1884 <tab> O B-LOC O O O O
	-- label files containing class labels for sequence labeling
	-- transfer file containing unlabeled data
	
- model_dir to store/restore model checkpoints

- task arguments
-- do_pairwise for pairwise classification tasks like MNLI and MRPC
-- do_NER for sequence labeling

- teacher arguments
-- pt_teacher for teacher model to distil (e.g., TFBertModel, TFRobertaModel, TFElectraModel)
-- pt_teacher_checkpoint for pre-trained teacher model checkpoints (e.g., bert-base-multilingual-cased, roberta-large, google/electra-base-discriminator)

- student arguments
-- pt_student_checkpoint to initialize from pre-trained small student models (e.g., MiniLM, DistilBERT, TinyBERT)
-- instead of a pre-trained checkpoint, initialize a raw student from scratch with
--- hidden_size
--- num_hidden_layers
--- num_attention_heads

- distillation features
-- distil_multi_hidden_states to distil multiple hidden states from the teacher
-- distil_attention to distil the deep attention network of the teacher
-- compress_word_embedding to initialize the student word embedding with an SVD-compressed teacher word embedding (useful for multilingual distillation; see the sketch after this list)
-- freeze_word_embedding to keep student word embeddings frozen during distillation (useful for multilingual distillation)
-- opt_policy (e.g., mixed_float16 for GPU and mixed_bfloat16 for TPU)
-- distil_chunk_size for using transfer data in chunks during distillation (reduce for OOM issues, checkpoints are saved after every distil_chunk_size steps)
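
As a rough illustration of what SVD-based embedding compression does (a standalone numpy sketch, not the repo's exact implementation), the teacher embedding matrix can be projected down to the student hidden size by keeping only the top singular directions:

import numpy as np

# Placeholder teacher embedding matrix: an mBERT-style vocabulary with 768-dim
# vectors, compressed to a 384-dim student embedding (shapes are illustrative).
vocab_size, teacher_dim, student_dim = 119547, 768, 384
teacher_embedding = np.random.randn(vocab_size, teacher_dim).astype(np.float32)

# Thin SVD; keep the top `student_dim` singular directions.
u, s, _ = np.linalg.svd(teacher_embedding, full_matrices=False)
student_embedding = u[:, :student_dim] * s[:student_dim]   # (vocab_size, student_dim)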

Model Outputs

The above training code generates intermediate model checkpoints so that training can resume after an abrupt termination instead of starting from scratch -- all saved in $$PT_OUTPUT_DIR. The final output of the model consists of (i) xtremedistil.h5 with the distilled model weights, (ii) xtremedistil-config.json with the training configuration, and (iii) word_embedding.npy with the input word embeddings of the student model.
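
A minimal sketch (assuming the files above sit in the current directory) for inspecting these artifacts:

import json
import numpy as np

with open("xtremedistil-config.json") as f:
    config = json.load(f)                        # training configuration used for distillation

word_embedding = np.load("word_embedding.npy")   # student input word embeddings
print(config)
print(word_embedding.shape)

# xtremedistil.h5 stores the distilled model weights; it can be restored with
# Keras model.load_weights("xtremedistil.h5") once the student architecture is
# rebuilt (the prediction scripts below handle this).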

Prediction

PYTHONHASHSEED=42 python run_xtreme_distil_predict.py \
--do_eval \
--model_dir $$PT_OUTPUT_DIR \
--do_predict \
--pred_file ../../datasets/NER/unlabeled.txt \
--opt_policy mixed_float16

ONNX Runtime Inference

You can also use ONNX Runtime for inference speedup with the following script:

PYTHONHASHSEED=42 python run_xtreme_distil_predict_onnx.py \
--do_eval \
--model_dir $$PT_OUTPUT_DIR \
--do_predict \
--pred_file ../../datasets/NER/unlabeled.txt

For details on the ONNX Runtime inference environment and arguments, refer to this notebook. The script is for online inference with batch_size=1.
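
For reference, a generic ONNX Runtime session (independent of the repo scripts) looks like the sketch below; the model path and input names are assumptions and depend on how the model was exported:

import numpy as np
import onnxruntime as ort

# Hypothetical export path; providers pinned to CPU for portability.
session = ort.InferenceSession("xtremedistil.onnx", providers=["CPUExecutionProvider"])
print([inp.name for inp in session.get_inputs()])   # check the actual input names

seq_len = 32
feed = {
    "input_ids": np.zeros((1, seq_len), dtype=np.int64),      # batch_size=1 online inference
    "attention_mask": np.ones((1, seq_len), dtype=np.int64),
}
outputs = session.run(None, feed)
print(outputs[0].shape)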

Continued Fine-tuning

You can continue fine-tuning the distilled/compressed student model on more labeled data with the following script:

PYTHONHASHSEED=42 python run_xtreme_distil_ft.py --model_dir $$PT_OUTPUT_DIR 

If you use this code, please cite:

@inproceedings{mukherjee-hassan-awadallah-2020-xtremedistil,
    title = "{X}treme{D}istil: Multi-stage Distillation for Massive Multilingual Models",
    author = "Mukherjee, Subhabrata  and
      Hassan Awadallah, Ahmed",
    booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics",
    month = jul,
    year = "2020",
    address = "Online",
    publisher = "Association for Computational Linguistics",
    url = "https://www.aclweb.org/anthology/2020.acl-main.202",
    pages = "2221--2234",
    abstract = "Deep and large pre-trained language models are the state-of-the-art for various natural language processing tasks. However, the huge size of these models could be a deterrent to using them in practice. Some recent works use knowledge distillation to compress these huge models into shallow ones. In this work we study knowledge distillation with a focus on multilingual Named Entity Recognition (NER). In particular, we study several distillation strategies and propose a stage-wise optimization scheme leveraging teacher internal representations, that is agnostic of teacher architecture, and show that it outperforms strategies employed in prior works. Additionally, we investigate the role of several factors like the amount of unlabeled data, annotation resources, model architecture and inference latency to name a few. We show that our approach leads to massive compression of teacher models like mBERT by upto 35x in terms of parameters and 51x in terms of latency for batch inference while retaining 95{\%} of its F1-score for NER over 41 languages.",
}

Code is released under the MIT license.
