Sequence-to-Sequence learning using PyTorch

Overview

Seq2Seq in PyTorch

This is a complete suite for training sequence-to-sequence models in PyTorch. It consists of several models and code to both train and infer using them.

Using this code you can train:

  • Neural-machine-translation (NMT) models
  • Language models
  • Image to caption generation
  • Skip-thought sentence representations
  • And more...

Installation

git clone --recursive https://github.com/eladhoffer/seq2seq.pytorch
cd seq2seq.pytorch; python setup.py develop

Models

Models currently available:

Datasets

Datasets currently available:

All datasets can be tokenized using 3 available segmentation methods:

  • Character based segmentation
  • Word based segmentation
  • Byte-pair-encoding (BPE) as suggested by bpe with selectable number of tokens.

After choosing a tokenization method, a vocabulary will be generated and saved for future inference.

Training methods

The models can be trained using several methods:

  • Basic Seq2Seq - given encoded sequence, generate (decode) output sequence. Training is done with teacher-forcing.
  • Multi Seq2Seq - where several tasks (such as multiple languages) are trained simultaneously by using the data sequences as both input to the encoder and output for decoder.
  • Image2Seq - used to train image to caption generators.

Usage

Example training scripts are available in scripts folder. Inference examples are available in examples folder.

  • example for training a transformer on WMT16 according to original paper regime:
DATASET=${1:-"WMT16_de_en"}
DATASET_DIR=${2:-"./data/wmt16_de_en"}
OUTPUT_DIR=${3:-"./results"}

WARMUP="4000"
LR0="512**(-0.5)"

python main.py \
  --save transformer \
  --dataset ${DATASET} \
  --dataset-dir ${DATASET_DIR} \
  --results-dir ${OUTPUT_DIR} \
  --model Transformer \
  --model-config "{'num_layers': 6, 'hidden_size': 512, 'num_heads': 8, 'inner_linear': 2048}" \
  --data-config "{'moses_pretok': True, 'tokenization':'bpe', 'num_symbols':32000, 'shared_vocab':True}" \
  --b 128 \
  --max-length 100 \
  --device-ids 0 \
  --label-smoothing 0.1 \
  --trainer Seq2SeqTrainer \
  --optimization-config "[{'step_lambda':
                          \"lambda t: { \
                              'optimizer': 'Adam', \
                              'lr': ${LR0} * min(t ** -0.5, t * ${WARMUP} ** -1.5), \
                              'betas': (0.9, 0.98), 'eps':1e-9}\"
                          }]"
  • example for training attentional LSTM based model with 3 layers in both encoder and decoder:
python main.py \
  --save de_en_wmt17 \
  --dataset ${DATASET} \
  --dataset-dir ${DATASET_DIR} \
  --results-dir ${OUTPUT_DIR} \
  --model RecurrentAttentionSeq2Seq \
  --model-config "{'hidden_size': 512, 'dropout': 0.2, \
                   'tie_embedding': True, 'transfer_hidden': False, \
                   'encoder': {'num_layers': 3, 'bidirectional': True, 'num_bidirectional': 1, 'context_transform': 512}, \
                   'decoder': {'num_layers': 3, 'concat_attention': True,\
                               'attention': {'mode': 'dot_prod', 'dropout': 0, 'output_transform': True, 'output_nonlinearity': 'relu'}}}" \
  --data-config "{'moses_pretok': True, 'tokenization':'bpe', 'num_symbols':32000, 'shared_vocab':True}" \
  --b 128 \
  --max-length 80 \
  --device-ids 0 \
  --trainer Seq2SeqTrainer \
  --optimization-config "[{'epoch': 0, 'optimizer': 'Adam', 'lr': 1e-3},
                          {'epoch': 6, 'lr': 5e-4},
                          {'epoch': 8, 'lr':1e-4},
                          {'epoch': 10, 'lr': 5e-5},
                          {'epoch': 12, 'lr': 1e-5}]" \
Owner
Elad Hoffer
Elad Hoffer
Homepage of paper: Paint Transformer: Feed Forward Neural Painting with Stroke Prediction, ICCV 2021.

Paint Transformer: Feed Forward Neural Painting with Stroke Prediction [Paper] [PaddlePaddle Implementation] Homepage of paper: Paint Transformer: Fee

442 Dec 16, 2022
​TextWorld is a sandbox learning environment for the training and evaluation of reinforcement learning (RL) agents on text-based games.

TextWorld A text-based game generator and extensible sandbox learning environment for training and testing reinforcement learning (RL) agents. Also ch

Microsoft 983 Dec 23, 2022
Reinforcement Learning for finance

Reinforcement Learning for Finance We apply reinforcement learning for stock trading. Fetch Data Example import utils # fetch symbols from yahoo fina

Tomoaki Fujii 159 Jan 03, 2023
A distributed deep learning framework that supports flexible parallelization strategies.

FlexFlow FlexFlow is a deep learning framework that accelerates distributed DNN training by automatically searching for efficient parallelization stra

528 Dec 25, 2022
Image morphing without reference points by applying warp maps and optimizing over them.

Differentiable Morphing Image morphing without reference points by applying warp maps and optimizing over them. Differentiable Morphing is machine lea

Alex K 380 Dec 19, 2022
Code for LIGA-Stereo Detector, ICCV'21

LIGA-Stereo Introduction This is the official implementation of the paper LIGA-Stereo: Learning LiDAR Geometry Aware Representations for Stereo-based

Xiaoyang Guo 75 Dec 09, 2022
Malmo Collaborative AI Challenge - Team Pig Catcher

The Malmo Collaborative AI Challenge - Team Pig Catcher Approach The challenge involves 2 agents who can either cooperate or defect. The optimal polic

Kai Arulkumaran 66 Jun 29, 2022
Official repository for the paper "Self-Supervised Models are Continual Learners" (CVPR 2022)

Self-Supervised Models are Continual Learners This is the official repository for the paper: Self-Supervised Models are Continual Learners Enrico Fini

Enrico Fini 73 Dec 18, 2022
Fashion Entity Classification

Fashion-Entity-Classification - Fashion-MNIST is a dataset of Zalando's article images—consisting of a training set of 60,000 examples and a test set of 10,000 examples. Each example is a 28x28 grays

ADITYA SHAH 1 Jan 04, 2022
Safe Model-Based Reinforcement Learning using Robust Control Barrier Functions

README Repository containing the code for the paper "Safe Model-Based Reinforcement Learning using Robust Control Barrier Functions". Specifically, an

Yousef Emam 13 Nov 24, 2022
HiddenMarkovModel implements hidden Markov models with Gaussian mixtures as distributions on top of TensorFlow

Class HiddenMarkovModel HiddenMarkovModel implements hidden Markov models with Gaussian mixtures as distributions on top of TensorFlow 2.0 Installatio

Susara Thenuwara 2 Nov 03, 2021
AniGAN: Style-Guided Generative Adversarial Networks for Unsupervised Anime Face Generation

AniGAN: Style-Guided Generative Adversarial Networks for Unsupervised Anime Face Generation AniGAN: Style-Guided Generative Adversarial Networks for U

Bing Li 81 Dec 14, 2022
A simple library that implements CLIP guided loss in PyTorch.

pytorch_clip_guided_loss: Pytorch implementation of the CLIP guided loss for Text-To-Image, Image-To-Image, or Image-To-Text generation. A simple libr

Sergei Belousov 74 Dec 26, 2022
Repository aimed at compiling code, papers, demos etc.. related to my PhD on 3D vision and machine learning for fruit detection and shape estimation at the university of Lincoln

PhD_3DPerception Repository aimed at compiling code, papers, demos etc.. related to my PhD on 3D vision and machine learning for fruit detection and s

lelouedec 2 Oct 06, 2022
PyTorch Lightning + Hydra. A feature-rich template for rapid, scalable and reproducible ML experimentation with best practices. ⚡🔥⚡

Lightning-Hydra-Template A clean and scalable template to kickstart your deep learning project 🚀 ⚡ 🔥 Click on Use this template to initialize new re

Łukasz Zalewski 2.1k Jan 09, 2023
KoCLIP: Korean port of OpenAI CLIP, in Flax

KoCLIP This repository contains code for KoCLIP, a Korean port of OpenAI's CLIP. This project was conducted as part of Hugging Face's Flax/JAX communi

Jake Tae 100 Jan 02, 2023
Discriminative Region Suppression for Weakly-Supervised Semantic Segmentation

Discriminative Region Suppression for Weakly-Supervised Semantic Segmentation (AAAI 2021) Official pytorch implementation of our paper: Discriminative

Beom 74 Dec 27, 2022
Baseline inference Algorithm for the STOIC2021 challenge.

STOIC2021 Baseline Algorithm This codebase contains an example submission for the STOIC2021 COVID-19 AI Challenge. As a baseline algorithm, it impleme

Luuk Boulogne 10 Aug 08, 2022
Implementation for "Conditional entropy minimization principle for learning domain invariant representation features"

Implementation for "Conditional entropy minimization principle for learning domain invariant representation features". The code is reproduced from thi

1 Nov 02, 2022
Source code for our Paper "Learning in High-Dimensional Feature Spaces Using ANOVA-Based Matrix-Vector Multiplication"

NFFT4ANOVA Source code for our Paper "Learning in High-Dimensional Feature Spaces Using ANOVA-Based Matrix-Vector Multiplication" This package uses th

Theresa Wagner 1 Aug 10, 2022