Text Generation by Learning from Demonstrations

Overview


The README was last updated on March 7, 2021. The repo is based on fairseq (v0.9.0).

Paper: arXiv

Prerequisites

Per fairseq usage, we need to install this particular modified version of fairseq. The simplest way: pip install --editable ./.

Due to PyTorch changes, and given that we're using a slightly older version of fairseq (see below), please use PyTorch version <= 1.6.0. However, the GOLD algorithm can be easily implemented on top of the latest fairseq (or most text generation codebases).

Datasets

For downloading the datasets:

  • CNN/DM and XSum: we follow the instructions here; note that this link does not correspond to the latest fairseq. Our version of the CNN/DM input articles includes the prepended "(CNN)" tags.
  • IWSLT14 De-En: we follow the instructions here. The binary files are provided in our repo, in the directory data-bin.
  • NQG: for the particular version of our NQG dataset, we follow the instructions here. The binary files are provided upon request.

Code: experiments on transformer models using fairseq

For reproducibility, the code is based on an April 2020 version of fairseq (based on release v0.9.0). However, it is easy to reimplement the GOLD algorithm in the latest version of fairseq, or in other frameworks.

How to implement in the latest version of fairseq?

  • If your GPUs have large memory, then most of the implementation happens around the criterion code (for question generation, summarization, and translation, the relevant file is ./fairseq/criterions/label_smoothed_cross_entropy.py in the April 2020 version of fairseq). The implementation in this repo uses this approach; a minimal sketch follows this list.
    • "Large memory" means the GPUs can store pi, pi-tilde, and p_MLE at the same time; see Algorithm 1 in the paper. In our experiments (using the same datasets, batch sizes, etc.), this would imply GPUs with ~24GB of memory.
  • If your GPUs cannot fit all three models, then you may need to input the p_MLE probabilities as features. This can be done by first saving the probabilities to a text or pickle file, and then loading them in the load_langpair_dataset function of ./fairseq/tasks/translation.py (or the corresponding file for other tasks).
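To make the criterion-side approach concrete, below is a minimal, hypothetical sketch of a GOLD-style weighted loss written against plain PyTorch rather than the repo's actual criterion class. The function name gold_style_loss, the iw_floor clipping value, and the specific per-token reward are illustrative assumptions, not the paper's exact formulation; see Algorithm 1 and ./fairseq/criterions/label_smoothed_cross_entropy.py for the real implementation.

    # A minimal, hypothetical sketch (not the repo's exact criterion code).
    # Assumptions: `model` is pi (being trained); `old_model` is pi-tilde, a frozen,
    # periodically re-synced copy of pi; `mle_model` is a frozen MLE-trained model
    # used to compute rewards. All three must fit in GPU memory (the "large memory" case).
    import torch
    import torch.nn.functional as F

    def gold_style_loss(model, old_model, mle_model, sample, iw_floor=0.1, pad_idx=1):
        # Log-probabilities of the demonstration tokens under pi.
        logits = model(**sample["net_input"])[0]
        lprobs = F.log_softmax(logits, dim=-1)
        targets = sample["target"].unsqueeze(-1)
        logp = lprobs.gather(-1, targets).squeeze(-1)  # log pi(a_t | s_t)

        with torch.no_grad():
            # Per-token importance weights from pi-tilde, clipped from below for stability.
            old_lprobs = F.log_softmax(old_model(**sample["net_input"])[0], dim=-1)
            iw = old_lprobs.gather(-1, targets).squeeze(-1).exp().clamp(min=iw_floor)

            # Per-token reward from the frozen MLE model. Using p_MLE(a_t | s_t) directly
            # is one simple choice; the paper's GOLD-p / GOLD-s variants define the
            # reward differently.
            mle_lprobs = F.log_softmax(mle_model(**sample["net_input"])[0], dim=-1)
            reward = mle_lprobs.gather(-1, targets).squeeze(-1).exp()

        mask = sample["target"].ne(pad_idx)  # fairseq's default pad index is 1
        return -(iw * reward * logp)[mask].sum()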

How to implement in other codebase?

  • See Algorithm 1 in the paper. The majority of the work happens around the loss computation. Three models must be ready when computing losses: (1) pi, the network we're training; (2) pi-tilde, a slightly older version of pi, created to ensure training stability (similar to the periodic synchronization in deep Q-learning; see the sketch below); (3) p_MLE, used to compute rewards (this one can be pre-loaded in the form of input features, in case the GPU cannot fit a third model).
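As a companion to the point above, here is a small sketch of the pi-tilde bookkeeping, assuming a standard PyTorch training loop; the helper name and the sync interval are illustrative assumptions.

    # Hypothetical sketch of maintaining pi-tilde as a frozen snapshot of pi,
    # re-synced every `sync_every` updates (akin to the target network in deep Q-learning).
    import copy

    def make_old_model(model):
        old_model = copy.deepcopy(model).eval()
        for p in old_model.parameters():
            p.requires_grad_(False)
        return old_model

    # Inside the training loop (sketch):
    #   if step % sync_every == 0:
    #       old_model.load_state_dict(model.state_dict())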

BART summarization generation fairseq issue

Given that there have been minor bugs in the fairseq BART summarization code (details on the original fairseq GitHub), we make the corresponding changes according to the fairseq authors' recommendations. (1) In ./fairseq/sequence_generator.py, see the modification here. (2) In ./fairseq/tasks/fairseq_task.py, see the modification here. (3) In ./fairseq/models/bart/hub_interface.py, see the modification here. The above is already implemented in this repo, but if you're reimplementing the GOLD code in the latest fairseq, beware of this issue (and keep the three modifications in mind).

How to run?

Training

The entry point for training is ./fairseq_cli/train.py; see ./fairseq/options.py for possible flags. For CNN/DM, the script for running GOLD-p is provided in run_cnndm_goldp.sh; the script for running GOLD-s (which often performs better than GOLD-p) is provided in run_cnndm_golds.sh. Scripts for other tasks are also provided. For explanations of the flags, please refer to ./fairseq/options.py as well as Algorithm 1 in the paper.

Validation

To validate, one option is to select the checkpoint with the highest BLEU/ROUGE-2 score on the dev set. Do not validate according to NLL loss: as shown in the paper, our models achieve higher accuracy but also higher perplexity (and thus higher NLL loss), so do not use checkpoint_best.pt. IWSLT14 De-En validation is implemented. For summarization, please use run_cnndm_validation.py (similar to run_cnndm_inference.py) as an example of looping through all checkpoints; then compute ROUGE with run_cnndm_validation_step2.sh (perhaps with small modifications). A hedged sketch of such a checkpoint loop for IWSLT14 is shown below.
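For instance, the following hypothetical sweep reuses the inference command from the section below with --gen-subset valid; the checkpoint paths and the BLEU-parsing regex are assumptions, not part of the repo's scripts.

    # Hypothetical checkpoint sweep: generate on the dev set with each checkpoint
    # and keep the one with the highest BLEU. Paths are placeholders.
    import glob
    import re
    import subprocess

    best_ckpt, best_bleu = None, -1.0
    for ckpt in sorted(glob.glob("checkpoints/checkpoint*.pt")):
        out = subprocess.run(
            ["python", "fairseq_cli/generate.py", "data-bin/iwslt14.tokenized.de-en",
             "--path", ckpt, "--batch-size", "128", "--beam", "5",
             "--remove-bpe", "--gen-subset", "valid"],
            capture_output=True, text=True,
        ).stdout
        match = re.search(r"BLEU4 = ([\d.]+)", out)  # assumes fairseq's BLEU summary line
        if match and float(match.group(1)) > best_bleu:
            best_ckpt, best_bleu = ckpt, float(match.group(1))
    print(f"best checkpoint: {best_ckpt} (BLEU {best_bleu})")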

Evaluation/inference

For BART evaluation, we use the inference scripts provided in run_cnndm_inference.sh, run_xsum_inference.sh, and run_squad_inference.sh. For IWSLT14 De-En inference, the following command will do.

python -W ignore [path-to-fairseq_cli/generate.py] data-bin/iwslt14.tokenized.de-en \
    --path [path-to-model-checkpoint.pt] \
    --batch-size 128 --beam 5 --remove-bpe --gen-subset test  > [path-to-save-to-file]

Transformer models

Please ensure the data is processed appropriately before using the models.

MLE model checkpoints

GOLD-s model checkpoints

We did not do much hyperparameter search for the transformer models, so it is likely that further search (over hyperparameters and architectures) could reach better performance.

Moreover, for summarization models, we evaluate with pyrouge + files2rouge, following the fairseq instructions, after installing pyrouge and files2rouge. The files2rouge package has a common WordNet-2.0.exc.db error; see this link for the fix.

Citation, authors, and contact

The bibtex entry

Richard Yuanzhe Pang

He He
