The Codebase for Causal Distillation for Language Models.

Overview

Python 3.7 · License: CC BY-NC

Causal Distillation for Language Models

Zhengxuan Wu*, Atticus Geiger*, Josh Rozner, Elisa Kreiss, Hanson Lu, Thomas Icard, Christopher Potts, Noah D. Goodman

This is an implementation of our preprint Causal Distillation for Language Models. The standard approach to distillation trains a student model against two objectives: a task-specific objective (e.g., language modeling) and an imitation objective that encourages the hidden states of the student model to be similar to those of the larger teacher model. In this paper, we show that it is beneficial to augment distillation with a third objective that encourages the student to imitate the causal computation process of the teacher through interchange intervention training (IIT).
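
To make the idea of an interchange intervention concrete, here is a minimal sketch in plain Python. It is not code from this repository: the two-layer toy "network" (`layer1`, `layer2`) and the function name `interchange_intervention` are hypothetical. The sketch runs the model on a base input, overwrites selected hidden units with the values they take on a source input, and then finishes the forward pass.

```python
from typing import List, Sequence

Vector = List[float]


def layer1(x: Sequence[float]) -> Vector:
    # Hypothetical toy hidden layer: two hidden units from two inputs.
    return [x[0] + x[1], x[0] - x[1]]


def layer2(h: Sequence[float]) -> float:
    # Hypothetical toy output layer: a fixed linear readout.
    return 2.0 * h[0] + 3.0 * h[1]


def interchange_intervention(
    base: Sequence[float], source: Sequence[float], units: Sequence[int]
) -> float:
    """Run the toy model on `base`, but with the hidden units listed in
    `units` replaced by the values they take when the input is `source`."""
    h_base = layer1(base)
    h_source = layer1(source)
    for i in units:
        h_base[i] = h_source[i]  # swap in the source activation
    return layer2(h_base)


if __name__ == "__main__":
    base, source = [1.0, 2.0], [4.0, 1.0]
    print(interchange_intervention(base, source, []))      # no swap: ordinary forward pass
    print(interchange_intervention(base, source, [1]))     # swap only hidden unit 1
    print(interchange_intervention(base, source, [0, 1]))  # swap all units: same as running on source
```

In the distillation setting, the same kind of swap would be applied at aligned locations in the teacher and the student, so that the student can be trained to match the teacher's output under the intervention; the toy above only shows the swap itself.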

We fork our main codebase from the Huggingface Distillation Interface.

Release Notes

12/02/2021 Our paper on Interchange Intervention Training (IIT) is released! It gives a more formal definition of the method.
12/06/2021 Released the causal distillation codebase with the preprint.
12/06/2021 Released evaluation results on distilled tiny-BERT (3 layers) with the Wiki-Text 103M dataset.
⬜️ Released evaluation results on causal-distilled tiny-BERT (3 layers) with the Wiki-Text 103M + BookCorpus dataset.
⬜️ Released evaluation results on causal-distilled BERT (6 layers) with the Wiki-Text 103M + BookCorpus dataset.
⬜️ Released more ablation studies.
⬜️ Released causal-distilled tiny-BERT (3 layers) model files.
⬜️ Released causal-distilled BERT (6 layers) model files.

If you experience any issues or have suggestions, please contact me either through the issues page or at [email protected].

Benchmark Results

Here are the results on the dev sets of GLUE:

Model                  Average-score [1]  CoLA  MNLI  MRPC  QNLI  QQP   RTE   SST-2  STS-B  WNLI
DistilBERT (3 layers)  67.81              22.8  71.6  78.2  82.1  84.3  55.4  86.5   56.7   24.2
CausalBERT (3 layers)  69.71              25.0  72.9  78.6  83.1  84.9  55.4  86.9   66.5   21.5

[1] Average-score computed without WNLI.

Citation

If you use this repository, please cite the following two papers: the paper on interchange intervention training and the paper on our distillation method.

  @article{geiger-etal-2021-iit,
        title={Inducing Causal Structure for Interpretable Neural Networks}, 
        author={Geiger, Atticus and Wu, Zhengxuan and Lu, Hanson and Rozner, Josh and Kreiss, Elisa and Icard, Thomas and Goodman, Noah D. and Potts, Christopher},
        year={2021},
        eprint={2112.00826},
        archivePrefix={arXiv},
        primaryClass={cs.LG}
  }

  @article{wu-etal-2021-distill,
        title={Causal Distillation for Language Models}, 
        author={Wu, Zhengxuan and Geiger, Atticus and Rozner, Josh and Kreiss, Elisa and Lu, Hanson and Icard, Thomas and Potts, Christopher and Goodman, Noah D.},
        year={2021},
        eprint={2112.02505},
        archivePrefix={arXiv},
        primaryClass={cs.CL}
  }

Requirements

  • Python 3.6 and 3.7 are supported.
  • PyTorch Version: 1.9.0
  • Transformers Version: 4.11.3
  • Datasets Version: 1.8.0
  • We have performed experiments on a Titan V GPU. We assume 12GB of GPU memory (more memory can expedite training).
  • Since we build our codebase off the Huggingface Distillation Interface, please review their documentation for requirements.

Dataset

Following the Huggingface Distillation Interface, the datasets must be pre-processed (binarized) before distillation; refer to their repo for details. We adapt their pre-processing scripts and add a few improvements. For example, datasets from the Huggingface Dataset Hub can now be binarized directly.

# preprocessing from disk
python scripts/binarized_data.py \
--file_path ../../bert-mid-tuning/data-files/wikitext-15M \
--split train \
--field_name text \
--max_parsing_example 1000 \
--tokenizer_type bert \
--tokenizer_name bert-base-uncased \
--dump_file ./data/binarized_text

# preprocessing from huggingface.
python scripts/binarized_data.py \
--dataset_name bookcorpus \
--split train \
--field_name text \
--tokenizer_type bert \
--tokenizer_name bert-base-uncased \
--dump_file bookcorpus-dataset/binarized_text \
--cache_dir ./distill_cache/

python scripts/binarized_data.py \
--dataset_name wikitext \
--split train \
--field_name text \
--tokenizer_type bert \
--tokenizer_name bert-base-uncased \
--dump_file wikitext-dataset/binarized_text \
--cache_dir ./distill_cache/

python scripts/binarized_data.py \
--dataset_name wikitext+bookcorpus \
--split train \
--field_name text \
--tokenizer_type bert \
--tokenizer_name bert-base-uncased \
--dump_file wikitext+bookcorpus-dataset/binarized_text \
--cache_dir ./distill_cache/

# helper scripts to combine two binarized data files
python scripts/data_combinator.py \
--file_path_left ./bookcorpus-dataset/binarized_text.train.bert-base-uncased.pickle \
--file_path_right ./wikitext-dataset/binarized_text.train.bert-base-uncased.pickle \
--split train \
--tokenizer_name bert-base-uncased \
--dump_file wikitext+bookcorpus-dataset/binarized_text

# multiprocessing preprocessor.
python scripts/binarized_data.py \
--dataset_name bookcorpus \
--split train \
--field_name text \
--tokenizer_type bert \
--tokenizer_name bert-base-uncased \
--dump_file bookcorpus-dataset/binarized_text \
--cache_dir ./distill_cache/ \
--fast_process \
--preprocessing_num_workers 48
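
As a rough illustration of what binarization produces, the sketch below turns lines of text into lists of token ids and pickles them. It is not the repository's `binarized_data.py`: the whitespace tokenizer, `TOY_VOCAB`, and the helper names are hypothetical stand-ins for the real BERT tokenizer, and the output naming scheme (`<dump_file>.<split>.<tokenizer_name>.pickle`) is inferred from the file paths used in the commands above.

```python
import pickle
from pathlib import Path
from typing import Dict, List, Sequence

# Hypothetical toy vocabulary standing in for the bert-base-uncased vocab.
TOY_VOCAB: Dict[str, int] = {
    "[UNK]": 0, "[CLS]": 1, "[SEP]": 2, "the": 3, "cat": 4, "sat": 5,
}


def binarize(lines: Sequence[str], vocab: Dict[str, int]) -> List[List[int]]:
    """Map each non-empty line to [CLS] + token ids + [SEP]."""
    rows = []
    for line in lines:
        words = line.lower().split()
        if not words:
            continue  # skip empty lines
        ids = [vocab.get(w, vocab["[UNK]"]) for w in words]
        rows.append([vocab["[CLS]"]] + ids + [vocab["[SEP]"]])
    return rows


def dump_path(dump_file: str, split: str, tokenizer_name: str) -> str:
    # Naming scheme inferred from the paths in the commands above.
    return f"{dump_file}.{split}.{tokenizer_name}.pickle"


def save(rows: List[List[int]], path: str) -> None:
    """Pickle the token-id lists, creating the parent directory if needed."""
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump(rows, f)


def combine(left: List[List[int]], right: List[List[int]]) -> List[List[int]]:
    # Assumption: combining two binarized files amounts to concatenation.
    return left + right
```

For example, `dump_path("bookcorpus-dataset/binarized_text", "train", "bert-base-uncased")` yields the file name passed to `--file_path_left` in the combinator command.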

After you get the datasets ready, you need to generate token counts as well.

python scripts/token_counts.py \
--data_file data/binarized_text.train.bert-base-uncased.pickle \
--token_counts_dump data/binarized_text.train.token_counts.bert-base-uncased.pickle \
--vocab_size 30522
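
The token-count step can be pictured as follows. This is a minimal sketch, not the repository's `token_counts.py`: the function names are hypothetical. It counts how often each token id occurs in the binarized data and returns a list of length `vocab_size`. The second helper shows one way such counts can be turned into masking weights that favor rarer tokens; the smoothing exponent `0.7` is an illustrative assumption.

```python
from collections import Counter
from typing import List, Sequence


def count_tokens(sequences: Sequence[Sequence[int]], vocab_size: int) -> List[int]:
    """Return counts[i] = number of occurrences of token id i."""
    counter: Counter = Counter()
    for ids in sequences:
        counter.update(ids)
    counts = [0] * vocab_size
    for token_id, n in counter.items():
        counts[token_id] = n
    return counts


def masking_weights(counts: Sequence[int], smoothing: float = 0.7) -> List[float]:
    """Unnormalized weights: rarer tokens get larger weights.

    Unseen tokens are treated as if they had been seen once.
    """
    return [max(c, 1) ** -smoothing for c in counts]


if __name__ == "__main__":
    data = [[1, 2, 2], [2, 0]]
    counts = count_tokens(data, vocab_size=4)
    print(counts)
    print(masking_weights(counts))
```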

Distillation

Before training, we recommend initializing your student model with weights extracted from the teacher model.

python scripts/extract_distilbert.py \
--model_type bert \
--model_name bert-base-uncased \
--dump_checkpoint ./distillation_checkpoints/bert-base-uncased_num_layer_3.pth \
--num_layers 3
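
Conceptually, the extraction step copies a subset of the teacher's layers into a shallower student. The sketch below illustrates that idea on a plain dictionary of parameter names. It is not the repository's `extract_distilbert.py`: `select_layers` is a hypothetical helper, the choice of teacher layers `[0, 5, 11]` is an arbitrary example, and the real script may also rename parameters to match the student architecture.

```python
import re
from typing import Dict, Sequence, TypeVar

T = TypeVar("T")

# Matches parameter names such as "bert.encoder.layer.11.attention.self.query.weight".
LAYER_KEY = re.compile(r"^(?P<prefix>.*encoder\.layer\.)(?P<idx>\d+)(?P<rest>\..*)$")


def select_layers(teacher_state: Dict[str, T], keep: Sequence[int]) -> Dict[str, T]:
    """Keep only the teacher layers in `keep`, renumbered 0..len(keep)-1.

    Parameters outside the encoder layers (e.g. embeddings) are copied as-is.
    """
    new_index = {old: new for new, old in enumerate(keep)}
    student: Dict[str, T] = {}
    for name, value in teacher_state.items():
        m = LAYER_KEY.match(name)
        if m is None:
            student[name] = value
            continue
        old = int(m.group("idx"))
        if old in new_index:
            new_name = m.group("prefix") + str(new_index[old]) + m.group("rest")
            student[new_name] = value
    return student


if __name__ == "__main__":
    # Strings stand in for weight tensors in this toy example.
    teacher = {f"bert.encoder.layer.{i}.attention.self.query.weight": f"W{i}" for i in range(12)}
    teacher["bert.embeddings.word_embeddings.weight"] = "E"
    for name, value in select_layers(teacher, keep=[0, 5, 11]).items():
        print(name, value)
```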

Below are two example commands: the first distills with our causal distillation objective, and the second distills without it.

CUDA_VISIBLE_DEVICES=9,4 python causal_train.py \
--force \
--n_gpu 2 \
--is_wandb \
--log_interval 10 \
--student_type distilbert \
--student_config ./training_configs/distilbert-base-uncased-small.json \
--student_pretrained_weights ./distillation_checkpoints/bert-base-uncased_num_layer_3.pth \
--teacher_type bert \
--teacher_name bert-base-uncased \
--neuron_mapping ./training_configs/single_middle.nm \
--mlm --alpha_ce 0.25 --alpha_mlm 0.25 --alpha_cos 0.25 --alpha_clm 0.0 --alpha_causal 0.25 \
--freeze_pos_embs \
--dump_path ./results/ \
--data_file ./wikitext-15M/binarized_text.train.bert-base-uncased.pickle \
--token_counts ./wikitext-15M/binarized_text.train.token_counts.bert-base-uncased.pickle \
--seed 42 \
--gradient_accumulation_steps 50 \
--n_epoch 3 \
--batch_size 5

CUDA_VISIBLE_DEVICES=0,1,2,3 python causal_train.py \
--force \
--n_gpu 4 \
--is_wandb \
--log_interval 10 \
--student_type distilbert \
--student_config ./training_configs/distilbert-base-uncased-small.json \
--student_pretrained_weights ./distillation_checkpoints/bert-base-uncased_num_layer_3.pth \
--teacher_type bert \
--teacher_name bert-base-uncased \
--neuron_mapping ./training_configs/single_middle.nm \
--mlm --alpha_ce 0.33 --alpha_mlm 0.33 --alpha_cos 0.33 --alpha_clm 0.0 --alpha_causal 0.00 \
--freeze_pos_embs \
--dump_path ./results/ \
--data_file ./wikitext-15M/binarized_text.train.bert-base-uncased.pickle \
--token_counts ./wikitext-15M/binarized_text.train.token_counts.bert-base-uncased.pickle \
--seed 42 \
--gradient_accumulation_steps 124 \
--n_epoch 6 \
--batch_size 4

Note that you can turn our causal distillation objective on or off through the --alpha_causal argument: the first command above sets it to 0.25, and the second sets it to 0.00.
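
The `--alpha_*` flags can be read as weights in a weighted sum of loss terms. The sketch below shows that reading with made-up loss values. It is a simplified illustration, not the training loop in `causal_train.py`: `total_loss` is a hypothetical helper, and the dictionary keys simply mirror the flag names.

```python
from typing import Dict


def total_loss(losses: Dict[str, float], alphas: Dict[str, float]) -> float:
    """Weighted sum of loss terms; a term with weight 0 contributes nothing."""
    return sum(alphas.get(name, 0.0) * value for name, value in losses.items())


if __name__ == "__main__":
    # Made-up per-term loss values, for illustration only.
    losses = {"ce": 2.0, "mlm": 4.0, "cos": 1.0, "causal": 3.0}

    # Weights from the first command above (causal objective on).
    with_causal = {"ce": 0.25, "mlm": 0.25, "cos": 0.25, "causal": 0.25}
    # Weights from the second command above (causal objective off).
    without_causal = {"ce": 0.33, "mlm": 0.33, "cos": 0.33, "causal": 0.0}

    print(total_loss(losses, with_causal))
    print(total_loss(losses, without_causal))
```

With `causal` weighted at 0.0, the causal term drops out of the sum regardless of its value, which is why the second command behaves as standard distillation.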

Evaluation

After you get your distilled models, you need to fine-tune and evaluate them on downstream tasks. We provide all the scripts you need.

MLM Evaluation

CUDA_VISIBLE_DEVICES=5 python run_mlm.py \
--model_name_or_path ./results/s_distilbert_t_bert_data_wikitext-15M_seed_42_mlm_True_ce_0.25_mlm_0.25_cos_0.25_causal_0.25_nm_single_multilayer/ \
--dataset_dir ../../bert-mid-tuning/data-files/wikitext-15M/ \
--tokenizer_name bert-base-uncased \
--do_eval \
--output_dir /tmp/test-mlm \
--cache_dir ./distill_cache/

GLUE Evaluation

CUDA_VISIBLE_DEVICES=5,7,8,9 python run_glue.py \
--model_name_or_path ./results/s_distilbert_t_bert_data_wikitext-dataset_seed_42_mlm_True_ce_0.33_mlm_0.33_cos_0.33_causal_0.0_nm_single_middle/ \
--tokenizer_name bert-base-uncased \
--task_name sst2 \
--do_train \
--do_eval \
--max_seq_length 128 \
--per_device_train_batch_size 32 \
--learning_rate 2e-5 \
--num_train_epochs 3 \
--output_dir ./results/ \
--save_total_limit 1 \
--cache_dir ./distill_cache/

CoNLL Evaluation

CUDA_VISIBLE_DEVICES=2,3,7,8 python run_ner.py \
--model_name_or_path ./results/s_distilbert_t_bert_data_wikitext-dataset_seed_42_mlm_True_ce_0.33_mlm_0.33_cos_0.33_causal_0.0_nm_single_middle_crossway_False/ \
--tokenizer_name bert-base-uncased \
--dataset_name conll2003 \
--do_train \
--do_eval \
--output_dir ./ner_results/ \
--save_total_limit 1 \
--cache_dir ./distill_cache/

SQuAD Evaluation

CUDA_VISIBLE_DEVICES=2,3,7,8 python run_qa.py \
--model_name_or_path ./results/s_distilbert_t_bert_data_wikitext-dataset_seed_42_mlm_True_ce_0.33_mlm_0.33_cos_0.33_causal_0.0_nm_single_middle_crossway_False/ \
--tokenizer_name bert-base-uncased \
--dataset_name squad \
--do_train \
--do_eval \
--per_device_train_batch_size 12 \
--learning_rate 3e-5 \
--num_train_epochs 2 \
--max_seq_length 384 \
--doc_stride 128 \
--save_total_limit 1 \
--output_dir ./qa_results/