The Codebase for Causal Distillation for Language Models.

Overview

Python 3.7 | License: CC BY-NC

Causal Distillation for Language Models

Zhengxuan Wu*, Atticus Geiger*, Josh Rozner, Elisa Kreiss, Hanson Lu, Thomas Icard, Christopher Potts, Noah D. Goodman

This is an implementation of our preprint Causal Distillation for Language Models. The standard approach to distillation trains a student model against two objectives: a task-specific objective (e.g., language modeling) and an imitation objective that encourages the hidden states of the student model to be similar to those of the larger teacher model. In this paper, we show that it is beneficial to augment distillation with a third objective that encourages the student to imitate the causal computation process of the teacher through interchange intervention training (IIT).
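To make the third objective more concrete, here is a toy sketch of an interchange intervention. It is not the repository's implementation: the two-layer `ToyModel`, its weights, the function names, and the squared-error loss are all assumptions chosen for illustration. The actual method intervenes on transformer hidden states selected by a neuron mapping.

```python
from typing import Dict, List, Optional

Vector = List[float]


def linear(weights: List[Vector], x: Vector) -> Vector:
    """Multiply a weight matrix (given as a list of rows) by a vector."""
    return [sum(w * v for w, v in zip(row, x)) for row in weights]


class ToyModel:
    """Hypothetical two-layer network: hidden = W1 x, output = W2 hidden."""

    def __init__(self, w1: List[Vector], w2: List[Vector]) -> None:
        self.w1 = w1
        self.w2 = w2

    def hidden(self, x: Vector) -> Vector:
        return linear(self.w1, x)

    def forward(self, x: Vector, swap: Optional[Dict[int, float]] = None) -> Vector:
        """Run the model; `swap` maps a hidden-unit index to a forced value."""
        h = self.hidden(x)
        if swap:
            h = [swap.get(i, v) for i, v in enumerate(h)]
        return linear(self.w2, h)


def interchange(model: ToyModel, base: Vector, source: Vector, units: List[int]) -> Vector:
    """Output on `base` when the hidden `units` take their values from `source`."""
    source_hidden = model.hidden(source)
    return model.forward(base, swap={i: source_hidden[i] for i in units})


def causal_loss(teacher: ToyModel, student: ToyModel, base: Vector, source: Vector,
                t_units: List[int], s_units: List[int]) -> float:
    """Mean squared gap between teacher and student counterfactual outputs.

    Squared error is a stand-in chosen for this sketch, not the paper's loss.
    """
    t_out = interchange(teacher, base, source, t_units)
    s_out = interchange(student, base, source, s_units)
    return sum((t - s) ** 2 for t, s in zip(t_out, s_out)) / len(t_out)


# Toy usage: apply the same intervention to both models and compare outputs.
teacher = ToyModel(w1=[[1.0, 0.0], [0.0, 1.0]], w2=[[1.0, 1.0]])
student = ToyModel(w1=[[1.0, 0.0], [0.0, 1.0]], w2=[[1.0, 0.5]])
base, source = [1.0, 2.0], [3.0, 4.0]
counterfactual = interchange(teacher, base, source, units=[0])
loss = causal_loss(teacher, student, base, source, t_units=[0], s_units=[0])
```

Minimizing such a loss pushes the student to respond to the swapped activations the way the teacher does, which is the sense in which it imitates the teacher's causal computation.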

We fork our main codebase from the Huggingface Distillation Interface.

Release Notes

12/02/2021 Our paper on Interchange Intervention Training (IIT) is released! Read it for a more formal definition of the method.
12/06/2021 Released the causal distillation codebase with the preprint.
12/06/2021 Released evaluation results on distilled tiny-BERT (3 layers) with the Wiki-Text 103M dataset.
⬜️ Released evaluation results on causal-distilled tiny-BERT (3 layers) with the Wiki-Text 103M + BookCorpus dataset.
⬜️ Released evaluation results on causal-distilled BERT (6 layers) with the Wiki-Text 103M + BookCorpus dataset.
⬜️ Released more ablation studies.
⬜️ Released causal-distilled tiny-BERT (3 layers) model files.
⬜️ Released causal-distilled BERT (6 layers) model files.

If you experience any issues or have suggestions, please contact me either through the issues page or at [email protected].

Benchmark Results

Here are the results on the dev sets of GLUE:

Model                  Average-score [1]  CoLA  MNLI  MRPC  QNLI  QQP   RTE   SST-2  STS-B  WNLI
DistilBERT (3 layers)  67.81              22.8  71.6  78.2  82.1  84.3  55.4  86.5   56.7   24.2
CausalBERT (3 layers)  69.71              25.0  72.9  78.6  83.1  84.9  55.4  86.9   66.5   21.5

[1] Average-score computed without WNLI.

Citation

If you use this repository, please cite the following two papers: the paper on interchange intervention training and the paper on our distillation method.

  @article{geiger-etal-2021-iit,
        title={Inducing Causal Structure for Interpretable Neural Networks}, 
        author={Geiger, Atticus and Wu, Zhengxuan and Lu, Hanson and Rozner, Josh and Kreiss, Elisa and Icard, Thomas and Goodman, Noah D. and Potts, Christopher},
        year={2021},
        eprint={2112.00826},
        archivePrefix={arXiv},
        primaryClass={cs.LG}
  }

  @article{wu-etal-2021-distill,
        title={Causal Distillation for Language Models}, 
        author={Wu, Zhengxuan and Geiger, Atticus and Rozner, Josh and Kreiss, Elisa and Lu, Hanson and Icard, Thomas and Potts, Christopher and Goodman, Noah D.},
        year={2021},
        eprint={2112.02505},
        archivePrefix={arXiv},
        primaryClass={cs.CL}
  }

Requirements

  • Python 3.6 or 3.7 is supported.
  • PyTorch Version: 1.9.0
  • Transformers Version: 4.11.3
  • Datasets Version: 1.8.0
  • We performed our experiments on a Titan V GPU. We assume 12GB of GPU memory (more memory can expedite training).
  • Since we build our codebase off the Huggingface Distillation Interface, please review their doc for requirements.

Dataset

Following the Huggingface Distillation Interface, the datasets must be pre-processed (binarized) before distillation; refer to their repo for details. We adapt their pre-processing scripts with a few improvements. For example, datasets from the Huggingface Dataset Hub can now be binarized directly.

# preprocessing from disk
python scripts/binarized_data.py \
--file_path ../../bert-mid-tuning/data-files/wikitext-15M \
--split train \
--field_name text \
--max_parsing_example 1000 \
--tokenizer_type bert \
--tokenizer_name bert-base-uncased \
--dump_file ./data/binarized_text

# preprocessing from huggingface.
python scripts/binarized_data.py \
--dataset_name bookcorpus \
--split train \
--field_name text \
--tokenizer_type bert \
--tokenizer_name bert-base-uncased \
--dump_file bookcorpus-dataset/binarized_text \
--cache_dir ./distill_cache/

python scripts/binarized_data.py \
--dataset_name wikitext \
--split train \
--field_name text \
--tokenizer_type bert \
--tokenizer_name bert-base-uncased \
--dump_file wikitext-dataset/binarized_text \
--cache_dir ./distill_cache/

python scripts/binarized_data.py \
--dataset_name wikitext+bookcorpus \
--split train \
--field_name text \
--tokenizer_type bert \
--tokenizer_name bert-base-uncased \
--dump_file wikitext+bookcorpus-dataset/binarized_text \
--cache_dir ./distill_cache/

# helper scripts to combine two binarized data files
python scripts/data_combinator.py \
--file_path_left ./bookcorpus-dataset/binarized_text.train.bert-base-uncased.pickle \
--file_path_right ./wikitext-dataset/binarized_text.train.bert-base-uncased.pickle \
--split train \
--tokenizer_name bert-base-uncased \
--dump_file wikitext+bookcorpus-dataset/binarized_text

# multiprocessing preprocessor.
python scripts/binarized_data.py \
--dataset_name bookcorpus \
--split train \
--field_name text \
--tokenizer_type bert \
--tokenizer_name bert-base-uncased \
--dump_file bookcorpus-dataset/binarized_text \
--cache_dir ./distill_cache/ \
--fast_process \
--preprocessing_num_workers 48

After you get the datasets ready, you need to generate token counts as well.

python scripts/token_counts.py \
--data_file data/binarized_text.train.bert-base-uncased.pickle \
--token_counts_dump data/binarized_text.train.token_counts.bert-base-uncased.pickle \
--vocab_size 30522
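As a rough illustration of what the token-count step produces, the sketch below counts token ids over a list of tokenized sequences and stores one count per vocabulary entry. It assumes the binarized file is a pickled list of token-id sequences, as in the Huggingface distillation scripts; the function name `count_tokens` and the toy sequences are hypothetical.

```python
import pickle
from collections import Counter
from typing import Iterable, List, Sequence


def count_tokens(sequences: Iterable[Sequence[int]], vocab_size: int) -> List[int]:
    """Return a list where position i holds the number of occurrences of token id i."""
    counter: Counter = Counter()
    for token_ids in sequences:
        counter.update(int(t) for t in token_ids)
    counts = [0] * vocab_size
    for token_id, n in counter.items():
        counts[token_id] = n
    return counts


# Toy stand-in for a binarized dataset (assumption: a list of token-id sequences).
toy_sequences = [[101, 7592, 2088, 102], [101, 7592, 102]]
toy_counts = count_tokens(toy_sequences, vocab_size=30522)

# Serialize the counts the same way a dump file would be written.
payload = pickle.dumps(toy_counts)
```

In the Huggingface distillation code, counts of this kind are used to weight which tokens get masked for MLM, so that rare tokens are sampled more often; we assume this codebase uses them the same way.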

Distillation

Before training, we recommend initializing your student model with weights extracted from the teacher model.

python scripts/extract_distilbert.py \
--model_type bert \
--model_name bert-base-uncased \
--dump_checkpoint ./distillation_checkpoints/bert-base-uncased_num_layer_3.pth \
--num_layers 3
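The idea behind this initialization can be sketched as copying a subset of teacher layers into a shallower student and renumbering them. This is a simplified illustration, not the logic of `extract_distilbert.py`: the choice of teacher layers `[0, 5, 11]`, the function name, and the decision to keep BERT-style parameter names (the real script would need to map them to the student's naming scheme) are all assumptions.

```python
from typing import Dict, List


def extract_student_state(
    teacher_state: Dict[str, object],
    teacher_layers: List[int],
    prefix: str = "bert.encoder.layer.",
) -> Dict[str, object]:
    """Copy selected teacher layers into consecutive student layer slots."""
    student_state: Dict[str, object] = {}
    for student_idx, teacher_idx in enumerate(teacher_layers):
        src = f"{prefix}{teacher_idx}."
        dst = f"{prefix}{student_idx}."
        for key, value in teacher_state.items():
            # The trailing dot in `src` keeps layer 1 from matching layer 11.
            if key.startswith(src):
                student_state[dst + key[len(src):]] = value
    # Parameters outside the encoder layers (e.g. embeddings) are copied as-is.
    for key, value in teacher_state.items():
        if not key.startswith(prefix):
            student_state[key] = value
    return student_state


# Toy 12-layer "teacher" with placeholder strings instead of tensors.
toy_teacher = {f"bert.encoder.layer.{i}.output.dense.weight": f"w{i}" for i in range(12)}
toy_teacher["bert.embeddings.word_embeddings.weight"] = "emb"

# Hypothetical choice of which teacher layers seed the 3-layer student.
toy_student = extract_student_state(toy_teacher, teacher_layers=[0, 5, 11])
```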

The two commands below are examples: the first distills with our causal distillation objective (--alpha_causal 0.25), and the second distills without it (--alpha_causal 0.00).

CUDA_VISIBLE_DEVICES=9,4 python causal_train.py \
--force \
--n_gpu 2 \
--is_wandb \
--log_interval 10 \
--student_type distilbert \
--student_config ./training_configs/distilbert-base-uncased-small.json \
--student_pretrained_weights ./distillation_checkpoints/bert-base-uncased_num_layer_3.pth \
--teacher_type bert \
--teacher_name bert-base-uncased \
--neuron_mapping ./training_configs/single_middle.nm \
--mlm --alpha_ce 0.25 --alpha_mlm 0.25 --alpha_cos 0.25 --alpha_clm 0.0 --alpha_causal 0.25 \
--freeze_pos_embs \
--dump_path ./results/ \
--data_file ./wikitext-15M/binarized_text.train.bert-base-uncased.pickle \
--token_counts ./wikitext-15M/binarized_text.train.token_counts.bert-base-uncased.pickle \
--seed 42 \
--gradient_accumulation_steps 50 \
--n_epoch 3 \
--batch_size 5

CUDA_VISIBLE_DEVICES=0,1,2,3 python causal_train.py \
--force \
--n_gpu 4 \
--is_wandb \
--log_interval 10 \
--student_type distilbert \
--student_config ./training_configs/distilbert-base-uncased-small.json \
--student_pretrained_weights ./distillation_checkpoints/bert-base-uncased_num_layer_3.pth \
--teacher_type bert \
--teacher_name bert-base-uncased \
--neuron_mapping ./training_configs/single_middle.nm \
--mlm --alpha_ce 0.33 --alpha_mlm 0.33 --alpha_cos 0.33 --alpha_clm 0.0 --alpha_causal 0.00 \
--freeze_pos_embs \
--dump_path ./results/ \
--data_file ./wikitext-15M/binarized_text.train.bert-base-uncased.pickle \
--token_counts ./wikitext-15M/binarized_text.train.token_counts.bert-base-uncased.pickle \
--seed 42 \
--gradient_accumulation_steps 124 \
--n_epoch 6 \
--batch_size 4

Note that the causal distillation objective is turned on or off through the --alpha_causal argument: a positive weight such as 0.25 enables it, and 0.00 disables it.
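The `--alpha_*` flags in the commands above weight the individual objectives. A minimal sketch of how such weights might combine, assuming the total loss is a plain weighted sum as in the Huggingface distillation code; the function name `combine_losses` and the toy loss values are hypothetical.

```python
from typing import Dict


def combine_losses(losses: Dict[str, float], alphas: Dict[str, float]) -> float:
    """Weighted sum of per-objective losses; objectives with weight 0 are skipped."""
    return sum(alphas[name] * losses[name] for name in alphas if alphas[name] > 0.0)


# Weights taken from the two example commands.
causal_alphas = {"ce": 0.25, "mlm": 0.25, "cos": 0.25, "clm": 0.0, "causal": 0.25}
baseline_alphas = {"ce": 0.33, "mlm": 0.33, "cos": 0.33, "clm": 0.0, "causal": 0.0}

# Made-up per-objective loss values, for illustration only.
toy_losses = {"ce": 2.0, "mlm": 4.0, "cos": 1.0, "clm": 0.0, "causal": 3.0}

with_causal = combine_losses(toy_losses, causal_alphas)
without_causal = combine_losses(toy_losses, baseline_alphas)
```

With `causal` weighted at 0.0, the causal term drops out of the sum and training reduces to the standard three-objective distillation setup.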

Evaluation

After you get your distilled models, fine-tune and evaluate them on downstream tasks. We provide all the scripts you need.

MLM Evaluation

CUDA_VISIBLE_DEVICES=5 python run_mlm.py \
--model_name_or_path ./results/s_distilbert_t_bert_data_wikitext-15M_seed_42_mlm_True_ce_0.25_mlm_0.25_cos_0.25_causal_0.25_nm_single_multilayer/ \
--dataset_dir ../../bert-mid-tuning/data-files/wikitext-15M/ \
--tokenizer_name bert-base-uncased \
--do_eval \
--output_dir /tmp/test-mlm \
--cache_dir ./distill_cache/

GLUE Evaluation

CUDA_VISIBLE_DEVICES=5,7,8,9 python run_glue.py \
--model_name_or_path ./results/s_distilbert_t_bert_data_wikitext-dataset_seed_42_mlm_True_ce_0.33_mlm_0.33_cos_0.33_causal_0.0_nm_single_middle/ \
--tokenizer_name bert-base-uncased \
--task_name sst2 \
--do_train \
--do_eval \
--max_seq_length 128 \
--per_device_train_batch_size 32 \
--learning_rate 2e-5 \
--num_train_epochs 3 \
--output_dir ./results/ \
--save_total_limit 1 \
--cache_dir ./distill_cache/

CoNLL Evaluation

CUDA_VISIBLE_DEVICES=2,3,7,8 python run_ner.py \
--model_name_or_path ./results/s_distilbert_t_bert_data_wikitext-dataset_seed_42_mlm_True_ce_0.33_mlm_0.33_cos_0.33_causal_0.0_nm_single_middle_crossway_False/ \
--tokenizer_name bert-base-uncased \
--dataset_name conll2003 \
--do_train \
--do_eval \
--output_dir ./ner_results/ \
--save_total_limit 1 \
--cache_dir ./distill_cache/

SQuAD Evaluation

CUDA_VISIBLE_DEVICES=2,3,7,8 python run_qa.py \
--model_name_or_path ./results/s_distilbert_t_bert_data_wikitext-dataset_seed_42_mlm_True_ce_0.33_mlm_0.33_cos_0.33_causal_0.0_nm_single_middle_crossway_False/ \
--tokenizer_name bert-base-uncased \
--dataset_name squad \
--do_train \
--do_eval \
--per_device_train_batch_size 12 \
--learning_rate 3e-5 \
--num_train_epochs 2 \
--max_seq_length 384 \
--doc_stride 128 \
--save_total_limit 1 \
--output_dir ./qa_results/