Memformer - Pytorch

Implementation of Memformer, a memory-augmented Transformer, in Pytorch. It maintains a set of memory slots, which are updated with attention and learned efficiently through Memory Replay Back-Propagation (MRBP) through time.
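
Conceptually, each segment's encoder output is used to rewrite the memory slots via cross-attention: the memory slots act as queries over both themselves and the segment's hidden states. A rough sketch of the idea, not the repo's actual code; the module and shapes below are simplified assumptions:

import torch
import torch.nn as nn

# toy illustration: memory slots attend over [memory; segment hidden states]
# to produce the next memory state; the real implementation is more involved
dim, slots = 512, 128
attn = nn.MultiheadAttention(dim, num_heads = 8, batch_first = True)

mems = torch.randn(1, slots, dim)       # current memory state
hiddens = torch.randn(1, 1024, dim)     # encoder output for one segment

context = torch.cat((mems, hiddens), dim = 1)
next_mems, _ = attn(mems, context, context)  # queries = memory slots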

Install

$ pip install memformer

Usage

Full encoder / decoder, as in the paper

import torch
from memformer import Memformer

model = Memformer(
    dim = 512,
    enc_num_tokens = 256,
    enc_depth = 2,
    enc_heads = 8,
    enc_max_seq_len = 1024,
    dec_num_tokens = 256,
    dec_depth = 2,
    dec_heads = 8,
    dec_max_seq_len = 1024,
    num_memory_slots = 128
)

src_seg_1 = torch.randint(0, 256, (1, 1024))
src_seg_2 = torch.randint(0, 256, (1, 1024))
src_seg_3 = torch.randint(0, 256, (1, 1024))

tgt = torch.randint(0, 256, (1, 1024))

enc_out1, mems1,    _ = model(src_seg_1) # (1, 1024, 512), (1, 128, 512), _
enc_out2, mems2,    _ = model(src_seg_2, mems = mems1)
enc_out3, mems3, loss = model(src_seg_3, tgt, mems = mems2)

loss.backward()
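
To make this a full training step, pair the backward call with an optimizer update, and detach the memories before carrying them into the next batch so the autograd graph does not grow without bound. A minimal sketch; the optimizer choice and learning rate are illustrative assumptions, not from the repo:

optim = torch.optim.Adam(model.parameters(), lr = 3e-4)  # assumed optimizer

optim.step()
optim.zero_grad()

mems = mems3.detach()  # stop gradients before reusing the memories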

Encoder only

import torch
from memformer import Memformer

model = Memformer(
    dim = 512,
    enc_num_tokens = 256,
    enc_heads = 8,
    enc_depth = 2,
    enc_max_seq_len = 1024,
    num_memory_slots = 128,
    num_mem_updates = 2,
    encoder_only = True       # use only the encoder; the forward pass returns the encoded output and memories
)

src1 = torch.randint(0, 256, (1, 1024))
src2 = torch.randint(0, 256, (1, 1024))

enc1, mems1 = model(src1) # (1, 1024, 512), (1, 128, 512)
enc2, mems2 = model(src2, mems = mems1)
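
The encoded output can then feed any downstream head. A minimal sketch of a sequence classifier on top of it; the pooling strategy and head are illustrative assumptions, not part of the library:

import torch.nn as nn

num_classes = 10                     # hypothetical label count
head = nn.Linear(512, num_classes)   # input size must match the model's `dim`

pooled = enc2.mean(dim = 1)          # (1, 512) - mean-pool over the sequence
logits = head(pooled)                # (1, num_classes)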

Memory Replay Back-Propagation

import torch
from memformer import Memformer, memory_replay_backprop

model = Memformer(
    dim = 512,
    num_memory_slots = 128,
    enc_num_tokens = 256,
    enc_depth = 2,
    enc_max_seq_len = 1024,
    dec_num_tokens = 256,
    dec_depth = 2,
    dec_max_seq_len = 1024
).cuda()

seq = torch.randint(0, 256, (1, 8192)).cuda()
seq_mask = torch.ones_like(seq).bool().cuda()

tgt = torch.randint(0, 256, (1, 512)).cuda()
tgt_mask = torch.ones_like(tgt).bool().cuda()

# will automatically split the 8192-token source sequence into 8 segments of enc_max_seq_len (1024)
memory_replay_backprop(
    model,
    src = seq,
    tgt = tgt,
    src_mask = seq_mask,
    tgt_mask = tgt_mask
)
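
Because memory_replay_backprop performs the backward pass internally, replaying the segments one at a time so gradients flow through the memory states without holding every segment's activations at once, the gradients are left accumulated on the model's parameters. A training step then only needs the optimizer update; the optimizer below is an illustrative assumption:

optim = torch.optim.Adam(model.parameters(), lr = 1e-4)  # assumed optimizer

# gradients were already accumulated by memory_replay_backprop above
optim.step()
optim.zero_grad()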

Citations

@inproceedings{anonymous2021memformer,
    title     = {Memformer: The Memory-Augmented Transformer},
    author    = {Anonymous},
    booktitle = {Submitted to International Conference on Learning Representations},
    year      = {2021},
    url       = {https://openreview.net/forum?id=_adSMszz_g9},
    note      = {under review}
}
Comments
  • WIP - MemformerEncoder

    I'm always trying out all your awesome work on transformers. My problem is NER on very large texts, with few examples.

    Memformer is the first one so far to converge faster and yield better accuracy than RNN encoders such as LSTM, SRU and IndRNN. It is ridiculously better than everything else I tested, congratulations @lucidrains 🥳

    I need to use the transformer as an encoder in my pipeline, to feed a CRF layer. So I modified the code to accept an already-embedded input, and to only do the encode step.

    TODO:

    • [ ] Support Mask
    • [ ] Re-utilize code with Memformer class

    Is this within the scope of the project?

    opened by bratao 10
  • ETA on complete examples

    @lucidrains As I asked about the feedback-transformer, I was also wondering about this memformer implementation, as I would love to try it. Any ETA on complete examples here? They would be much appreciated. Thanks.

    And similarly, I would love to see a simple example for custom line-by-line TXT datasets as well.

    Thank you again :)

    opened by asigalov61 0
Owner
Phil Wang
Working with Attention. It's all we need