Scalable Attentive Sentence-Pair Modeling via Distilled Sentence Embedding (AAAI 2020) - PyTorch Implementation

Overview

Scalable Attentive Sentence-Pair Modeling via Distilled Sentence Embedding

PyTorch implementation for the Scalable Attentive Sentence-Pair Modeling via Distilled Sentence Embedding (AAAI 2020) paper.

Method Description

Distilled Sentence Embedding (DSE) distills knowledge from a finetuned state-of-the-art transformer model (BERT) to create high quality sentence embeddings. For a complete description, as well as implementation details and hyperparameters, please refer to the paper.

Usage

Follow the instructions below in order to run the training procedure of the Distilled Sentence Embedding (DSE) method. The python scripts below can be run with the -h parameter to get more information.

1. Install Requirements

Tested with Python 3.6+.

pip install -r requirements.txt

2. Download GLUE Datasets

Run the download_glue_data.py script to download the GLUE datasets.

python download_glue_data.py

3. Finetune BERT on a Specific Task

Finetune a standard BERT model on a specific task (e.g., MRPC, MNLI, etc.). Below is an example for the MRPC dataset.

python finetune_bert.py \
--bert_model bert-large-uncased-whole-word-masking \
--task_name mrpc \
--do_train \
--do_eval \
--do_lower_case \
--data_dir glue_data/MRPC \
--max_seq_length 128 \
--train_batch_size 32 \
--gradient_accumulation_steps 2 \
--learning_rate 2e-5 \
--num_train_epochs 3 \
--output_dir outputs/large_uncased_finetuned_mrpc \
--overwrite_output_dir \
--no_parallel

Note: For our code to work with the AllNLI dataset (a combination of the MNLI and SNLI datasets), you simply need to create a folder where the downloaded GLUE datasets are and copy the MNLI and SNLI datasets into it.

4. Create Logits for Distillation from the Finetuned BERT

Execute the following command to create the logits which will be used for the distillation training objective. Note that the bert_checkpoint_dir parameter has to match the output_dir from the previous command.

python run_distillation_logits_creator.py \
--task_name mrpc \
--data_dir glue_data/MRPC \
--do_lower_case \
--train_features_path glue_data/MRPC/train_bert-large-uncased-whole-word-masking_128_mrpc \
--bert_checkpoint_dir outputs/large_uncased_finetuned_mrpc

5. Train the DSE Model using the Finetuned BERT Logits

Train the DSE model using the extracted logits. Notice that the distillation_logits_path parameter needs to be changed according to the task.

python dse_train_runner.py \
--task_name mrpc \
--data_dir glue_data/MRPC \
--distillation_logits_path outputs/logits/large_uncased_finetuned_mrpc_logits.pt \
--do_lower_case \
--file_log \
--epochs 8 \
--store_checkpoints \
--fc_dims 512 1

Important Notes:

  • To store checkpoints for the model make sure that the --store_checkpoints flag is passed as shown above.
  • The fc_dims parameter accepts a list of space separated integers, and is the dimensions of the fully connected classifier that is put on top of the extracted features from the Siamese DSE model. The output dimension (in this case 1) needs to be changed according to the wanted output dimensionality. For example, for the MNLI dataset the fc_dims parameter should be 512 3 since it is a 3 class classification task.

6. Loading the Trained DSE Model

During training, checkpoints of the Trainer object which contains the model will be saved. You can load these checkpoints and extract the model state dictionary from them. Then you can load the state into a created DSESiameseClassifier model. The load_dse_checkpoint_example.py script contains an example of how to do that.

To load the model trained with the example commands above you can use:

python load_dse_checkpoint_example.py \
--task_name mrpc \
--trainer_checkpoint <path_to_saved_checkpoint> \
--do_lower_case \
--fc_dims 512 1

Acknowledgments

Citation

If you find this code useful, please cite the following paper:

@inproceedings{barkan2020scalable,
  title={Scalable Attentive Sentence-Pair Modeling via Distilled Sentence Embedding},
  author={Barkan, Oren and Razin, Noam and Malkiel, Itzik and Katz, Ori and Caciularu, Avi and Koenigstein, Noam},
  booktitle={AAAI Conference on Artificial Intelligence (AAAI)},
  year={2020}
}
Owner
Microsoft
Open source projects and samples from Microsoft
Microsoft
TCPNet - Temporal-attentive-Covariance-Pooling-Networks-for-Video-Recognition

Temporal-attentive-Covariance-Pooling-Networks-for-Video-Recognition This is an implementation of TCPNet. Introduction For video recognition task, a g

Zilin Gao 21 Dec 08, 2022
Code for the paper "Training GANs with Stronger Augmentations via Contrastive Discriminator" (ICLR 2021)

Training GANs with Stronger Augmentations via Contrastive Discriminator (ICLR 2021) This repository contains the code for reproducing the paper: Train

Jongheon Jeong 174 Dec 29, 2022
MetaTTE: a Meta-Learning Based Travel Time Estimation Model for Multi-city Scenarios

MetaTTE: a Meta-Learning Based Travel Time Estimation Model for Multi-city Scenarios This is the official TensorFlow implementation of MetaTTE in the

morningstarwang 4 Dec 14, 2022
B2EA: An Evolutionary Algorithm Assisted by Two Bayesian Optimization Modules for Neural Architecture Search

B2EA: An Evolutionary Algorithm Assisted by Two Bayesian Optimization Modules for Neural Architecture Search This is the offical implementation of the

SNU ADSL 0 Feb 07, 2022
My freqtrade strategies

My freqtrade-strategies Hi there! This is repo for my freqtrade-strategies. My name is Ilya Zelenchuk, I'm a lecturer at the SPbU university (https://

171 Dec 05, 2022
A Pytorch Implementation of ClariNet

ClariNet A Pytorch Implementation of ClariNet (Mel Spectrogram -- Waveform) Requirements PyTorch 0.4.1 & python 3.6 & Librosa Examples Step 1. Downlo

Sungwon Kim 286 Sep 15, 2022
SNE-RoadSeg in PyTorch, ECCV 2020

SNE-RoadSeg Introduction This is the official PyTorch implementation of SNE-RoadSeg: Incorporating Surface Normal Information into Semantic Segmentati

242 Dec 20, 2022
pytorch implementation of ABC : Auxiliary Balanced Classifier for Class-imbalanced Semi-supervised Learning

ABC:Auxiliary Balanced Classifier for Class-imbalanced Semi-supervised Learning, NeurIPS 2021 pytorch implementation of ABC : Auxiliary Balanced Class

Hyuck Lee 25 Dec 22, 2022
A tensorflow implementation of Fully Convolutional Networks For Semantic Segmentation

##A tensorflow implementation of Fully Convolutional Networks For Semantic Segmentation. #USAGE To run the trained classifier on some images: python w

Alex Seewald 13 Nov 17, 2022
Point cloud processing tool library.

Point Cloud ToolBox This point cloud processing tool library can be used to process point clouds, 3d meshes, and voxels. Environment python 3.7.5 Dep

ZhangXinyun 40 Dec 09, 2022
This is RFA-Toolbox, a simple and easy-to-use library that allows you to optimize your neural network architectures using receptive field analysis (RFA) and create graph visualizations of your architecture.

ReceptiveFieldAnalysisToolbox This is RFA-Toolbox, a simple and easy-to-use library that allows you to optimize your neural network architectures usin

84 Nov 23, 2022
Code for paper " AdderNet: Do We Really Need Multiplications in Deep Learning?"

AdderNet: Do We Really Need Multiplications in Deep Learning? This code is a demo of CVPR 2020 paper AdderNet: Do We Really Need Multiplications in De

HUAWEI Noah's Ark Lab 915 Jan 01, 2023
Our CIKM21 Paper "Incorporating Query Reformulating Behavior into Web Search Evaluation"

Reformulation-Aware-Metrics Introduction This codebase contains source-code of the Python-based implementation of our CIKM 2021 paper. Chen, Jia, et a

xuanyuan14 5 Mar 05, 2022
LSUN Dataset Documentation and Demo Code

LSUN Please check LSUN webpage for more information about the dataset. Data Release All the images in one category are stored in one lmdb database fil

Fisher Yu 426 Jan 02, 2023
Algorithmic trading with deep learning experiments

Deep-Trading Algorithmic trading with deep learning experiments. Now released part one - simple time series forecasting. I plan to implement more soph

Alex Honchar 1.4k Jan 02, 2023
QSYM: A Practical Concolic Execution Engine Tailored for Hybrid Fuzzing

QSYM: A Practical Concolic Execution Engine Tailored for Hybrid Fuzzing Environment Tested on Ubuntu 14.04 64bit and 16.04 64bit Installation # disabl

gts3.org (<a href=[email protected])"> 581 Dec 30, 2022
Unsupervised Learning of Video Representations using LSTMs

Unsupervised Learning of Video Representations using LSTMs Code for paper Unsupervised Learning of Video Representations using LSTMs by Nitish Srivast

Elman Mansimov 341 Dec 20, 2022
Partial implementation of ODE-GAN technique from the paper Training Generative Adversarial Networks by Solving Ordinary Differential Equations

ODE GAN (Prototype) in PyTorch Partial implementation of ODE-GAN technique from the paper Training Generative Adversarial Networks by Solving Ordinary

Somshubra Majumdar 15 Feb 10, 2022
Source code of generalized shuffled linear regression

Generalized-Shuffled-Linear-Regression Code for the ICCV 2021 paper: Generalized Shuffled Linear Regression. Authors: Feiran Li, Kent Fujiwara, Fumio

FEI 7 Oct 26, 2022
Code for NeurIPS 2021 paper 'Spatio-Temporal Variational Gaussian Processes'

Spatio-Temporal Variational GPs This repository is the official implementation of the methods in the publication: O. Hamelijnck, W.J. Wilkinson, N.A.

AaltoML 26 Sep 16, 2022