Scalable Attentive Sentence-Pair Modeling via Distilled Sentence Embedding (AAAI 2020) - PyTorch Implementation

Overview

Scalable Attentive Sentence-Pair Modeling via Distilled Sentence Embedding

PyTorch implementation for the Scalable Attentive Sentence-Pair Modeling via Distilled Sentence Embedding (AAAI 2020) paper.

Method Description

Distilled Sentence Embedding (DSE) distills knowledge from a finetuned state-of-the-art transformer model (BERT) to create high quality sentence embeddings. For a complete description, as well as implementation details and hyperparameters, please refer to the paper.

Usage

Follow the instructions below in order to run the training procedure of the Distilled Sentence Embedding (DSE) method. The python scripts below can be run with the -h parameter to get more information.

1. Install Requirements

Tested with Python 3.6+.

pip install -r requirements.txt

2. Download GLUE Datasets

Run the download_glue_data.py script to download the GLUE datasets.

python download_glue_data.py

3. Finetune BERT on a Specific Task

Finetune a standard BERT model on a specific task (e.g., MRPC, MNLI, etc.). Below is an example for the MRPC dataset.

python finetune_bert.py \
--bert_model bert-large-uncased-whole-word-masking \
--task_name mrpc \
--do_train \
--do_eval \
--do_lower_case \
--data_dir glue_data/MRPC \
--max_seq_length 128 \
--train_batch_size 32 \
--gradient_accumulation_steps 2 \
--learning_rate 2e-5 \
--num_train_epochs 3 \
--output_dir outputs/large_uncased_finetuned_mrpc \
--overwrite_output_dir \
--no_parallel

Note: For our code to work with the AllNLI dataset (a combination of the MNLI and SNLI datasets), you simply need to create a folder where the downloaded GLUE datasets are and copy the MNLI and SNLI datasets into it.

4. Create Logits for Distillation from the Finetuned BERT

Execute the following command to create the logits which will be used for the distillation training objective. Note that the bert_checkpoint_dir parameter has to match the output_dir from the previous command.

python run_distillation_logits_creator.py \
--task_name mrpc \
--data_dir glue_data/MRPC \
--do_lower_case \
--train_features_path glue_data/MRPC/train_bert-large-uncased-whole-word-masking_128_mrpc \
--bert_checkpoint_dir outputs/large_uncased_finetuned_mrpc

5. Train the DSE Model using the Finetuned BERT Logits

Train the DSE model using the extracted logits. Notice that the distillation_logits_path parameter needs to be changed according to the task.

python dse_train_runner.py \
--task_name mrpc \
--data_dir glue_data/MRPC \
--distillation_logits_path outputs/logits/large_uncased_finetuned_mrpc_logits.pt \
--do_lower_case \
--file_log \
--epochs 8 \
--store_checkpoints \
--fc_dims 512 1

Important Notes:

  • To store checkpoints for the model make sure that the --store_checkpoints flag is passed as shown above.
  • The fc_dims parameter accepts a list of space separated integers, and is the dimensions of the fully connected classifier that is put on top of the extracted features from the Siamese DSE model. The output dimension (in this case 1) needs to be changed according to the wanted output dimensionality. For example, for the MNLI dataset the fc_dims parameter should be 512 3 since it is a 3 class classification task.

6. Loading the Trained DSE Model

During training, checkpoints of the Trainer object which contains the model will be saved. You can load these checkpoints and extract the model state dictionary from them. Then you can load the state into a created DSESiameseClassifier model. The load_dse_checkpoint_example.py script contains an example of how to do that.

To load the model trained with the example commands above you can use:

python load_dse_checkpoint_example.py \
--task_name mrpc \
--trainer_checkpoint <path_to_saved_checkpoint> \
--do_lower_case \
--fc_dims 512 1

Acknowledgments

Citation

If you find this code useful, please cite the following paper:

@inproceedings{barkan2020scalable,
  title={Scalable Attentive Sentence-Pair Modeling via Distilled Sentence Embedding},
  author={Barkan, Oren and Razin, Noam and Malkiel, Itzik and Katz, Ori and Caciularu, Avi and Koenigstein, Noam},
  booktitle={AAAI Conference on Artificial Intelligence (AAAI)},
  year={2020}
}
Owner
Microsoft
Open source projects and samples from Microsoft
Microsoft
Official PyTorch implementation of BlobGAN: Spatially Disentangled Scene Representations

BlobGAN: Spatially Disentangled Scene Representations Official PyTorch Implementation Paper | Project Page | Video | Interactive Demo BlobGAN.mp4 This

148 Dec 29, 2022
A series of Python scripts to access measurements from Fluke 28X meters. Fluke IR Remote Interface required.

Fluke289_data_access A series of Python scripts to access measurements from Fluke 28X meters. Fluke IR Remote Interface required. Created from informa

3 Dec 08, 2022
HairCLIP: Design Your Hair by Text and Reference Image

Overview This repository hosts the official PyTorch implementation of the paper: "HairCLIP: Design Your Hair by Text and Reference Image". Our single

322 Jan 06, 2023
General purpose Slater-Koster tight-binding code for electronic structure calculations

tight-binder Introduction General purpose tight-binding code for electronic structure calculations based on the Slater-Koster approximation. The code

9 Dec 15, 2022
Image classification for projects and researches

This is a tool to help you quickly solve classification problems including: data analysis, training, report results and model explanation.

Nguyễn Trường Lâu 2 Dec 27, 2021
Import Python modules from dicts and JSON formatted documents.

Paker Paker is module for importing Python packages/modules from dictionaries and JSON formatted documents. It was inspired by httpimporter. Important

Wojciech Wentland 1 Sep 07, 2022
Baleen: Robust Multi-Hop Reasoning at Scale via Condensed Retrieval (NeurIPS'21)

Baleen Baleen is a state-of-the-art model for multi-hop reasoning, enabling scalable multi-hop search over massive collections for knowledge-intensive

Stanford Future Data Systems 22 Dec 05, 2022
Black-Box-Tuning - Black-Box Tuning for Language-Model-as-a-Service

Black-Box-Tuning Source code for paper "Black-Box Tuning for Language-Model-as-a

Tianxiang Sun 149 Jan 04, 2023
PyTorch implementation for Convolutional Networks with Adaptive Inference Graphs

Convolutional Networks with Adaptive Inference Graphs (ConvNet-AIG) This repository contains a PyTorch implementation of the paper Convolutional Netwo

Andreas Veit 176 Dec 07, 2022
Tilted Empirical Risk Minimization (ICLR '21)

Tilted Empirical Risk Minimization This repository contains the implementation for the paper Tilted Empirical Risk Minimization ICLR 2021 Empirical ri

Tian Li 40 Nov 28, 2022
Rayvens makes it possible for data scientists to access hundreds of data services within Ray with little effort.

Rayvens augments Ray with events. With Rayvens, Ray applications can subscribe to event streams, process and produce events. Rayvens leverages Apache

CodeFlare 32 Dec 25, 2022
An Approach to Explore Logistic Regression Models

User-centered Regression An Approach to Explore Logistic Regression Models This tool applies the potential of Attribute-RadViz in identifying correlat

0 Nov 12, 2021
Weight initialization schemes for PyTorch nn.Modules

nninit Weight initialization schemes for PyTorch nn.Modules. This is a port of the popular nninit for Torch7 by @kaixhin. ##Update This repo has been

Alykhan Tejani 69 Jan 26, 2021
RoMa: A lightweight library to deal with 3D rotations in PyTorch.

RoMa: A lightweight library to deal with 3D rotations in PyTorch. RoMa (which stands for Rotation Manipulation) provides differentiable mappings betwe

NAVER 90 Dec 27, 2022
Official implementation of SynthTIGER (Synthetic Text Image GEneratoR) ICDAR 2021

🐯 SynthTIGER: Synthetic Text Image GEneratoR Official implementation of SynthTIGER | Paper | Datasets Moonbin Yim1, Yoonsik Kim1, Han-cheol Cho1, Sun

Clova AI Research 256 Jan 05, 2023
[CVPR 2020] Local Class-Specific and Global Image-Level Generative Adversarial Networks for Semantic-Guided Scene Generation

Contents Local and Global GAN Cross-View Image Translation Semantic Image Synthesis Acknowledgments Related Projects Citation Contributions Collaborat

Hao Tang 131 Dec 07, 2022
Harmonious Textual Layout Generation over Natural Images via Deep Aesthetics Learning

Harmonious Textual Layout Generation over Natural Images via Deep Aesthetics Learning Code for the paper Harmonious Textual Layout Generation over Nat

7 Aug 09, 2022
Code for this paper The Lottery Ticket Hypothesis for Pre-trained BERT Networks.

The Lottery Ticket Hypothesis for Pre-trained BERT Networks Code for this paper The Lottery Ticket Hypothesis for Pre-trained BERT Networks. [NeurIPS

VITA 122 Dec 14, 2022
code for "Self-supervised edge features for improved Graph Neural Network training",

Self-supervised edge features for improved Graph Neural Network training Data availability: Here is a link to the raw data for the organoids dataset.

Neal Ravindra 23 Dec 02, 2022
potpourri3d - An invigorating blend of 3D geometry tools in Python.

A Python library of various algorithms and utilities for 3D triangle meshes and point clouds. Managed by Nicholas Sharp, with new tools added lazily as needed. Currently, mainly bindings to C++ tools

Nicholas Sharp 295 Jan 05, 2023