Code for "Scheduled Sampling Based on Decoding Steps for Neural Machine Translation" (long paper, EMNLP-2021)


Scheduled Sampling Based on Decoding Steps for Neural Machine Translation (EMNLP-2021 main conference)


Overview

We propose conducting scheduled sampling based on decoding steps instead of the original training steps: the probability of feeding the ground-truth token, rather than the model's own prediction, decreases as the decoding step grows. We observe that our proposal more realistically simulates the distribution of real translation errors, thus better bridging the gap between training and inference. The paper has been accepted to the main conference of EMNLP-2021.
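To make the contrast concrete, here is a minimal plain-Python sketch, assuming an exponential decay with an illustrative radix of 0.9 (not a value from the paper). Under the original scheduled sampling, the probability of feeding the ground-truth token depends only on the training step, so every position in a sentence shares the same value; under decoding-step scheduling it depends on the position t, so later tokens are more likely to be replaced by model predictions. The function names are hypothetical.

```python
def training_step_schedule(training_step, seq_len, radix=0.9):
    """Original scheduled sampling (sketch): one probability per training step,
    shared by every decoding position of the sentence."""
    p = radix ** training_step
    return [p] * seq_len


def decoding_step_schedule(seq_len, radix=0.9):
    """Decoding-step scheduled sampling (sketch): the probability of feeding
    the ground-truth token decays with the decoding step t = 0, ..., seq_len - 1."""
    return [radix ** t for t in range(seq_len)]


if __name__ == "__main__":
    # a fixed training step: all five positions get the same probability
    print(training_step_schedule(training_step=3, seq_len=5))
    # decoding-step schedule: earlier positions are more likely to see ground truth
    print([round(p, 3) for p in decoding_step_schedule(seq_len=5)])
```

In practice the training-step radix would be much closer to 1, since training runs for many thousands of steps; the small radix here only keeps the numbers readable.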

Background


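We conduct scheduled sampling for the Transformer with a two-pass decoder. The first pass decodes with ground-truth inputs as usual; its predictions are then mixed with the ground-truth tokens, and the second pass decodes again on the mixed inputs. The pseudo-code is as follows: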

# first pass: the same as the standard Transformer decoder
first_decoder_outputs = decoder(first_decoder_inputs)

# sample decoder inputs between model predictions and ground-truth tokens
second_decoder_inputs = sampling_function(first_decoder_outputs, first_decoder_inputs, max_seq_len, tgt_lengths)

# second pass: run the decoder again on the above sampled inputs
second_decoder_outputs = decoder(second_decoder_inputs)
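One detail the pseudo-code above leaves implicit: in a standard teacher-forced decoder, the output at position i predicts the token that is fed as input at position i + 1, so first-pass predictions are typically shifted right by one before they replace ground-truth inputs. The sketch below illustrates this on token ids in plain Python. It is an assumption-based illustration, not the repository's implementation; `bos_id`, the thresholds, and both helper names are hypothetical.

```python
import random


def shift_predictions(predicted_tokens, bos_id):
    """Align first-pass predictions with decoder inputs (assumption: output i
    predicts the token fed as input i + 1). Position 0 always receives <bos>."""
    return [bos_id] + predicted_tokens[:-1]


def build_second_pass_inputs(gold_inputs, predicted_tokens, thresholds, bos_id, rng):
    """For each position t, keep the ground-truth input with probability
    thresholds[t]; otherwise use the shifted first-pass prediction."""
    shifted = shift_predictions(predicted_tokens, bos_id)
    return [
        gold if rng.random() < p else pred
        for gold, pred, p in zip(gold_inputs, shifted, thresholds)
    ]


if __name__ == "__main__":
    BOS = 0
    gold_inputs = [BOS, 11, 12, 13, 14]     # <bos> y1 y2 y3 y4 (shifted target)
    first_pass_preds = [11, 22, 13, 24, 2]  # argmax predictions of y1..y4 and <eos>
    thresholds = [0.9 ** t for t in range(len(gold_inputs))]
    rng = random.Random(0)
    print(build_second_pass_inputs(gold_inputs, first_pass_preds, thresholds, BOS, rng))
```

With thresholds of 1.0 everywhere this reduces to ordinary teacher forcing, and with 0.0 everywhere the second pass sees only the model's own (shifted) predictions.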

Quick to Use

Our approach is applicable to most autoregressive tasks. Please try the following pseudo-code when conducting scheduled sampling:

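import torch

def sampling_function(first_decoder_outputs, first_decoder_inputs, max_seq_len, tgt_lengths,
                      sampling_strategy="exponential", exp_radix=0.99, sigmoid_k=20.0, epsilon=0.1):
    '''
    conduct scheduled sampling based on the index of decoded tokens
    param first_decoder_outputs: [batch_size, seq_len, hidden_size], representations of model predictions
    param first_decoder_inputs: [batch_size, seq_len, hidden_size], embeddings of ground-truth target tokens
    param max_seq_len: scalar, the max length of target sequences
    param tgt_lengths: [batch_size] LongTensor, the lengths of target sequences in a mini-batch
    param sampling_strategy: "exponential", "sigmoid" or "linear"
    param exp_radix, sigmoid_k, epsilon: hyper-parameters of the above strategies
        (the default values here are illustrative assumptions, not tuned settings)
    '''
    batch_size, seq_len = first_decoder_inputs.shape[:2]
    device = first_decoder_inputs.device

    # indices of decoding steps
    t = torch.arange(max_seq_len, dtype=torch.float, device=device)

    # different sampling strategies based on decoding steps:
    # threshold_table[t] is the probability of feeding the ground-truth token at step t
    if sampling_strategy == "exponential":
        threshold_table = exp_radix ** t
    elif sampling_strategy == "sigmoid":
        threshold_table = sigmoid_k / (sigmoid_k + torch.exp(t / sigmoid_k))
    elif sampling_strategy == "linear":
        threshold_table = torch.clamp(1 - t / max_seq_len, min=epsilon)
    else:
        raise ValueError("Unknown sampling_strategy %s" % sampling_strategy)

    # convert threshold_table to [batch_size, seq_len]; row i of the lower-triangular
    # table keeps steps 0..i, so row (length - 1) zeroes out the padding positions
    threshold_table = threshold_table.unsqueeze(0).repeat(max_seq_len, 1).tril()
    thresholds = threshold_table[tgt_lengths - 1]
    thresholds = thresholds[:, :seq_len]

    # keep the ground-truth input where the random seed falls below the threshold,
    # otherwise use the model prediction; the mask is broadcast over hidden_size
    random_select_seed = torch.rand([batch_size, seq_len], device=device)
    use_ground_truth = (random_select_seed < thresholds).unsqueeze(-1)
    second_decoder_inputs = torch.where(use_ground_truth, first_decoder_inputs, first_decoder_outputs)

    return second_decoder_inputs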
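As a quick sanity check of the three schedules, the plain-Python sketch below evaluates the same formulas as the `threshold_table` in `sampling_function` for a few decoding steps. The hyper-parameter values (radix 0.9, k = 5, max length 10, epsilon 0.1) are illustrative assumptions, not settings from the paper, and the helper name `threshold` is hypothetical.

```python
import math


def threshold(t, strategy, max_seq_len=10, exp_radix=0.9, sigmoid_k=5.0, epsilon=0.1):
    """Probability of keeping the ground-truth input at decoding step t.
    Hyper-parameter defaults are illustrative assumptions."""
    if strategy == "exponential":
        return exp_radix ** t
    if strategy == "sigmoid":
        return sigmoid_k / (sigmoid_k + math.exp(t / sigmoid_k))
    if strategy == "linear":
        return max(epsilon, 1 - t / max_seq_len)
    raise ValueError("Unknown sampling_strategy %s" % strategy)


if __name__ == "__main__":
    for strategy in ("exponential", "sigmoid", "linear"):
        values = [round(threshold(t, strategy), 3) for t in (0, 5, 9)]
        print(strategy, values)
```

Note that the sigmoid schedule never starts at exactly 1: at t = 0 it equals k / (k + 1), while the exponential and linear schedules start at 1 and the linear one is floored at epsilon.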
    

Further Usage

Error accumulation is a common phenomenon in NLP tasks. Whenever you want to simulate how errors accumulate, our method may come in handy, for example:

# sampling tokens between noisy target tokens and ground-truth tokens
decoder_inputs = sampling_function(noisy_decoder_inputs, golden_decoder_inputs, max_seq_len, tgt_lengths)

# computing the decoder with the above sampled tokens
decoder_outputs = decoder(decoder_inputs)

# sampling utterances between model predictions and ground-truth utterances
contexts = sampling_function(predicted_utterances, golden_utterances, max_turns, current_turns)

model_predictions = dialogue_model(contexts, target_inputs)
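Both uses above can be sketched without a model. Below, a hypothetical `add_token_noise` helper corrupts target tokens by random replacement (an assumption for illustration; the repository may build noisy targets differently), and one generic, hypothetical `mix_by_step` function keeps the golden item at step t with probability radix ** t, whether the items are tokens in a sentence or utterances in a dialogue history. The argument order (alternative first, golden second) mirrors the calls to `sampling_function` above.

```python
import random


def add_token_noise(tokens, vocab_size, noise_prob, rng):
    """Hypothetical noise model: replace each token with a random id
    with probability noise_prob."""
    return [rng.randrange(vocab_size) if rng.random() < noise_prob else tok for tok in tokens]


def mix_by_step(alternative, golden, radix, rng):
    """Keep golden[t] with probability radix ** t, otherwise take alternative[t].
    Works for tokens in a sentence or utterances in a dialogue history."""
    return [
        gold if rng.random() < radix ** t else alt
        for t, (alt, gold) in enumerate(zip(alternative, golden))
    ]


if __name__ == "__main__":
    rng = random.Random(42)

    # token level: noisy target tokens vs. ground-truth tokens
    golden_tokens = [5, 17, 42, 8, 99, 3]
    noisy_tokens = add_token_noise(golden_tokens, vocab_size=100, noise_prob=0.5, rng=rng)
    print(mix_by_step(noisy_tokens, golden_tokens, radix=0.8, rng=rng))

    # turn level: predicted utterances vs. ground-truth utterances
    golden_utterances = ["hi", "how are you?", "fine, thanks", "see you"]
    predicted_utterances = ["hello", "how r u", "fine", "bye"]
    print(mix_by_step(predicted_utterances, golden_utterances, radix=0.8, rng=rng))
```

At the turn level, the step index plays the role of the dialogue turn, so later turns in the history are more likely to be replaced by model-generated utterances.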

Experiments

We provide scripts to reproduce the results in this paper (NMT and text summarization).

Citation

Please cite this paper if you find this repo useful.

@inproceedings{liu_ss_decoding_2021,
    title = "Scheduled Sampling Based on Decoding Steps for Neural Machine Translation",
    author = "Liu, Yijin  and
      Meng, Fandong  and
      Chen, Yufeng  and
      Xu, Jinan  and
      Zhou, Jie",
    booktitle = "Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing (EMNLP)",
    year = "2021",
    address = "Online"
}

Contact

Please feel free to contact us with any further questions.
