Implementation of Memformer, a Memory-augmented Transformer, in Pytorch

Overview

Memformer - Pytorch

Implementation of Memformer, a Memory-augmented Transformer, in Pytorch. It includes memory slots, which are updated with attention, learned efficiently through Memory-Replay BackPropagation (MRBP) through time.

Install

$ pip install memformer

Usage

Full encoder / decoder, as in the paper

import torch
from memformer import Memformer

model = Memformer(
    dim = 512,
    enc_num_tokens = 256,
    enc_depth = 2,
    enc_heads = 8,
    enc_max_seq_len = 1024,
    dec_num_tokens = 256,
    dec_depth = 2,
    dec_heads = 8,
    dec_max_seq_len = 1024,
    num_memory_slots = 128
)

src_seg_1 = torch.randint(0, 256, (1, 1024))
src_seg_2 = torch.randint(0, 256, (1, 1024))
src_seg_3 = torch.randint(0, 256, (1, 1024))

tgt = torch.randint(0, 256, (1, 1024))

enc_out1, mems1,    _ = model(src_seg_1) # (1, 1024, 512), (1, 128, 512), _
enc_out2, mems2,    _ = model(src_seg_2, mems = mems1)
enc_out3, mems3, loss = model(src_seg_3, tgt, mems = mems2)

loss.backward()

Encoder only

import torch
from memformer import Memformer

model = Memformer(
    dim = 512,
    enc_num_tokens = 256,
    enc_heads = 8,
    enc_depth = 2,
    enc_max_seq_len = 1024,
    num_memory_slots = 128,
    num_mem_updates = 2,
    encoder_only = True       # only use encoder, in which output is encoded output
)

src1 = torch.randint(0, 256, (1, 1024))
src2 = torch.randint(0, 256, (1, 1024))

enc1, mems1 = model(src1) # (1, 1024, 512), (1, 128, 512)
enc2, mems2 = model(src2, mems = mems1)

Memory Replay Back-Propagation

import torch
from memformer import Memformer, memory_replay_backprop

model = Memformer(
    dim = 512,
    num_memory_slots = 128,
    enc_num_tokens = 256,
    enc_depth = 2,
    enc_max_seq_len = 1024,
    dec_num_tokens = 256,
    dec_depth = 2,
    dec_max_seq_len = 1024
).cuda()

seq = torch.randint(0, 256, (1, 8192)).cuda()
seq_mask = torch.ones_like(seq).bool().cuda()

tgt = torch.randint(0, 256, (1, 512)).cuda()
tgt_mask = torch.ones_like(tgt).bool().cuda()

# will automatically split the source sequence to 8 segments
memory_replay_backprop(
    model,
    src = seq,
    tgt = tgt,
    src_mask = seq_mask,
    tgt_mask = tgt_mask
)

Citations

@inproceedings{
    anonymous2021memformer,
    title={Memformer: The Memory-Augmented Transformer},
    author={Anonymous},
    booktitle={Submitted to International Conference on Learning Representations},
    year={2021},
    url={https://openreview.net/forum?id=_adSMszz_g9},
    note={under review}
}
You might also like...
Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch
Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch

Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch

Styled Augmented Translation
Styled Augmented Translation

SAT Style Augmented Translation Introduction By collecting high-quality data, we were able to train a model that outperforms Google Translate on 6 dif

TANL: Structured Prediction as Translation between Augmented Natural Languages

TANL: Structured Prediction as Translation between Augmented Natural Languages Code for the paper "Structured Prediction as Translation between Augmen

A neuroanatomy-based augmented reality experience powered by computer vision. Features 3D visuals of the Atlas Brain Map slices.

Brain Augmented Reality (AR) A neuroanatomy-based augmented reality experience powered by computer vision that features 3D visuals of the Atlas Brain

Motion Planner Augmented Reinforcement Learning for Robot Manipulation in Obstructed Environments (CoRL 2020)
Motion Planner Augmented Reinforcement Learning for Robot Manipulation in Obstructed Environments (CoRL 2020)

Motion Planner Augmented Reinforcement Learning for Robot Manipulation in Obstructed Environments [Project website] [Paper] This project is a PyTorch

A heterogeneous entity-augmented academic language model based on Open Academic Graph (OAG)
A heterogeneous entity-augmented academic language model based on Open Academic Graph (OAG)

Library | Paper | Slack We released two versions of OAG-BERT in CogDL package. OAG-BERT is a heterogeneous entity-augmented academic language model wh

DrQ-v2: Improved Data-Augmented Reinforcement Learning
DrQ-v2: Improved Data-Augmented Reinforcement Learning

DrQ-v2: Improved Data-Augmented RL Agent Method DrQ-v2 is a model-free off-policy algorithm for image-based continuous control. DrQ-v2 builds on DrQ,

[EMNLP 2021] Distantly-Supervised Named Entity Recognition with Noise-Robust Learning and Language Model Augmented Self-Training

RoSTER The source code used for Distantly-Supervised Named Entity Recognition with Noise-Robust Learning and Language Model Augmented Self-Training, p

 RNG-KBQA: Generation Augmented Iterative Ranking for Knowledge Base Question Answering
RNG-KBQA: Generation Augmented Iterative Ranking for Knowledge Base Question Answering

RNG-KBQA: Generation Augmented Iterative Ranking for Knowledge Base Question Answering Authors: Xi Ye, Semih Yavuz, Kazuma Hashimoto, Yingbo Zhou and

Comments
  • WIP - MemformerEncoder

    WIP - MemformerEncoder

    I´m always trying all your awesome work on transformers. My problem is NER on very large texts, with few examples.

    Memformer is the first one so far to converge faster and wield better accuracy than RNN encoders as LSTM, SRU and IndRNN It is ridiculously better than everything else I tested, congratulations @lucidrains 🥳

    I need to use the transformer as a Encoder in my pipeline, to feed a CRF layer. So I modified the code to accept an already embedded input, and to only do the Encode step.

    TODO:

    • [ ] Support Mask
    • [ ] Re-utilize code with Memformer class

    Is this within the scope of the project?

    opened by bratao 10
  • ETA on complete examples

    ETA on complete examples

    @lucidrains As I asked about the feedback-transformer, I was also wondering about this memformer implementation as I would love to try it. Any eta on any complete examples here? They will be much appreciated. Thanks.

    And similarly, I would love to see a simple example for custom line-by-line TXT datasets as well.

    Thank you again :)

    opened by asigalov61 0
Owner
Phil Wang
Working with Attention. It's all we need
Phil Wang
A lightweight library to compare different PyTorch implementations of the same network architecture.

TorchBug is a lightweight library designed to compare two PyTorch implementations of the same network architecture. It allows you to count, and compar

Arjun Krishnakumar 5 Jan 02, 2023
Proposal, Tracking and Segmentation (PTS): A Cascaded Network for Video Object Segmentation

Proposal, Tracking and Segmentation (PTS): A Cascaded Network for Video Object Segmentation By Qiang Zhou*, Zilong Huang*, Lichao Huang, Han Shen, Yon

Forest 117 Apr 01, 2022
TraND: Transferable Neighborhood Discovery for Unsupervised Cross-domain Gait Recognition.

TraND This is the code for the paper "Jinkai Zheng, Xinchen Liu, Chenggang Yan, Jiyong Zhang, Wu Liu, Xiaoping Zhang and Tao Mei: TraND: Transferable

Jinkai Zheng 32 Apr 04, 2022
Deep Inside Convolutional Networks - This is a caffe implementation to visualize the learnt model

Deep Inside Convolutional Networks This is a caffe implementation to visualize the learnt model. Part of a class project at Georgia Tech Problem State

Jigar 61 Apr 15, 2022
Benchmarks for the Optimal Power Flow Problem

Power Grid Lib - Optimal Power Flow This benchmark library is curated and maintained by the IEEE PES Task Force on Benchmarks for Validation of Emergi

A Library of IEEE PES Power Grid Benchmarks 207 Dec 08, 2022
GeneralOCR is open source Optical Character Recognition based on PyTorch.

Introduction GeneralOCR is open source Optical Character Recognition based on PyTorch. It makes a fidelity and useful tool to implement SOTA models on

57 Dec 29, 2022
Inference code for "StylePeople: A Generative Model of Fullbody Human Avatars" paper. This code is for the part of the paper describing video-based avatars.

NeuralTextures This is repository with inference code for paper "StylePeople: A Generative Model of Fullbody Human Avatars" (CVPR21). This code is for

Visual Understanding Lab @ Samsung AI Center Moscow 18 Oct 06, 2022
CSE-519---Project - Job Title Analysis (Project for CSE 519 - Data Science Fundamentals)

A Multifaceted Approach to Job Title Analysis CSE 519 - Data Science Fundamentals Project Description Project consists of three parts: Salary Predicti

Jimit Dholakia 1 Jan 04, 2022
Monocular 3D pose estimation. OpenVINO. CPU inference or iGPU (OpenCL) inference.

human-pose-estimation-3d-python-cpp RealSenseD435 (RGB) 480x640 + CPU Corei9 45 FPS (Depth is not used) 1. Run 1-1. RealSenseD435 (RGB) 480x640 + CPU

Katsuya Hyodo 8 Oct 03, 2022
Face Recognition plus identification simply and fast | Python

PyFaceDetection Face Recognition plus identification simply and fast Ubuntu Setup sudo pip3 install numpy sudo pip3 install cmake sudo pip3 install dl

Peyman Majidi Moein 16 Sep 22, 2022
A whale detector design for the Kaggle whale-detector challenge!

CNN (InceptionV1) + STFT based Whale Detection Algorithm So, this repository is my PyTorch solution for the Kaggle whale-detection challenge. The obje

Tarin Ziyaee 92 Sep 28, 2021
Single object tracking and segmentation.

Single/Multiple Object Tracking and Segmentation Codes and comparison of recent single/multiple object tracking and segmentation. News 💥 AutoMatch is

ZP ZHANG 385 Jan 02, 2023
Cross-modal Deep Face Normals with Deactivable Skip Connections

Cross-modal Deep Face Normals with Deactivable Skip Connections Victoria Fernández Abrevaya*, Adnane Boukhayma*, Philip H. S. Torr, Edmond Boyer (*Equ

72 Nov 27, 2022
KGDet: Keypoint-Guided Fashion Detection (AAAI 2021)

KGDet: Keypoint-Guided Fashion Detection (AAAI 2021) This is an official implementation of the AAAI-2021 paper "KGDet: Keypoint-Guided Fashion Detecti

Qian Shenhan 35 Dec 29, 2022
Materials for my scikit-learn tutorial

Scikit-learn Tutorial Jake VanderPlas email: [email protected] twitter: @jakevdp gith

Jake Vanderplas 1.6k Dec 30, 2022
This repository contains the implementation of the HealthGen model, a generative model to synthesize realistic EHR time series data with missingness

HealthGen: Conditional EHR Time Series Generation This repository contains the implementation of the HealthGen model, a generative model to synthesize

0 Jan 20, 2022
pytorchのスライス代入操作をonnxに変換する際にScatterNDならないようにするサンプル

pytorch_remove_ScatterND pytorchのスライス代入操作をonnxに変換する際にScatterNDならないようにするサンプル。 スライスしたtensorにそのまま代入してしまうとScatterNDになるため、計算結果をcatで新しいtensorにする。 python ver

2 Dec 01, 2022
Official PyTorch implementation of Data-free Knowledge Distillation for Object Detection, WACV 2021.

Introduction This repository is the official PyTorch implementation of Data-free Knowledge Distillation for Object Detection, WACV 2021. Data-free Kno

NVIDIA Research Projects 50 Jan 05, 2023
PyTorch implementation of Graph Convolutional Networks in Feature Space for Image Deblurring and Super-resolution, IJCNN 2021.

GCResNet PyTorch implementation of Graph Convolutional Networks in Feature Space for Image Deblurring and Super-resolution, IJCNN 2021. The code will

11 May 19, 2022