Implementation of Memformer, a Memory-augmented Transformer, in Pytorch

Overview

Memformer - Pytorch

Implementation of Memformer, a Memory-augmented Transformer, in Pytorch. It includes memory slots, which are updated with attention, learned efficiently through Memory-Replay BackPropagation (MRBP) through time.

Install

$ pip install memformer

Usage

Full encoder / decoder, as in the paper

import torch
from memformer import Memformer

model = Memformer(
    dim = 512,
    enc_num_tokens = 256,
    enc_depth = 2,
    enc_heads = 8,
    enc_max_seq_len = 1024,
    dec_num_tokens = 256,
    dec_depth = 2,
    dec_heads = 8,
    dec_max_seq_len = 1024,
    num_memory_slots = 128
)

src_seg_1 = torch.randint(0, 256, (1, 1024))
src_seg_2 = torch.randint(0, 256, (1, 1024))
src_seg_3 = torch.randint(0, 256, (1, 1024))

tgt = torch.randint(0, 256, (1, 1024))

enc_out1, mems1,    _ = model(src_seg_1) # (1, 1024, 512), (1, 128, 512), _
enc_out2, mems2,    _ = model(src_seg_2, mems = mems1)
enc_out3, mems3, loss = model(src_seg_3, tgt, mems = mems2)

loss.backward()

Encoder only

import torch
from memformer import Memformer

model = Memformer(
    dim = 512,
    enc_num_tokens = 256,
    enc_heads = 8,
    enc_depth = 2,
    enc_max_seq_len = 1024,
    num_memory_slots = 128,
    num_mem_updates = 2,
    encoder_only = True       # only use encoder, in which output is encoded output
)

src1 = torch.randint(0, 256, (1, 1024))
src2 = torch.randint(0, 256, (1, 1024))

enc1, mems1 = model(src1) # (1, 1024, 512), (1, 128, 512)
enc2, mems2 = model(src2, mems = mems1)

Memory Replay Back-Propagation

import torch
from memformer import Memformer, memory_replay_backprop

model = Memformer(
    dim = 512,
    num_memory_slots = 128,
    enc_num_tokens = 256,
    enc_depth = 2,
    enc_max_seq_len = 1024,
    dec_num_tokens = 256,
    dec_depth = 2,
    dec_max_seq_len = 1024
).cuda()

seq = torch.randint(0, 256, (1, 8192)).cuda()
seq_mask = torch.ones_like(seq).bool().cuda()

tgt = torch.randint(0, 256, (1, 512)).cuda()
tgt_mask = torch.ones_like(tgt).bool().cuda()

# will automatically split the source sequence to 8 segments
memory_replay_backprop(
    model,
    src = seq,
    tgt = tgt,
    src_mask = seq_mask,
    tgt_mask = tgt_mask
)

Citations

@inproceedings{
    anonymous2021memformer,
    title={Memformer: The Memory-Augmented Transformer},
    author={Anonymous},
    booktitle={Submitted to International Conference on Learning Representations},
    year={2021},
    url={https://openreview.net/forum?id=_adSMszz_g9},
    note={under review}
}
You might also like...
Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch
Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch

Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch

Styled Augmented Translation
Styled Augmented Translation

SAT Style Augmented Translation Introduction By collecting high-quality data, we were able to train a model that outperforms Google Translate on 6 dif

TANL: Structured Prediction as Translation between Augmented Natural Languages

TANL: Structured Prediction as Translation between Augmented Natural Languages Code for the paper "Structured Prediction as Translation between Augmen

A neuroanatomy-based augmented reality experience powered by computer vision. Features 3D visuals of the Atlas Brain Map slices.

Brain Augmented Reality (AR) A neuroanatomy-based augmented reality experience powered by computer vision that features 3D visuals of the Atlas Brain

Motion Planner Augmented Reinforcement Learning for Robot Manipulation in Obstructed Environments (CoRL 2020)
Motion Planner Augmented Reinforcement Learning for Robot Manipulation in Obstructed Environments (CoRL 2020)

Motion Planner Augmented Reinforcement Learning for Robot Manipulation in Obstructed Environments [Project website] [Paper] This project is a PyTorch

A heterogeneous entity-augmented academic language model based on Open Academic Graph (OAG)
A heterogeneous entity-augmented academic language model based on Open Academic Graph (OAG)

Library | Paper | Slack We released two versions of OAG-BERT in CogDL package. OAG-BERT is a heterogeneous entity-augmented academic language model wh

DrQ-v2: Improved Data-Augmented Reinforcement Learning
DrQ-v2: Improved Data-Augmented Reinforcement Learning

DrQ-v2: Improved Data-Augmented RL Agent Method DrQ-v2 is a model-free off-policy algorithm for image-based continuous control. DrQ-v2 builds on DrQ,

[EMNLP 2021] Distantly-Supervised Named Entity Recognition with Noise-Robust Learning and Language Model Augmented Self-Training

RoSTER The source code used for Distantly-Supervised Named Entity Recognition with Noise-Robust Learning and Language Model Augmented Self-Training, p

 RNG-KBQA: Generation Augmented Iterative Ranking for Knowledge Base Question Answering
RNG-KBQA: Generation Augmented Iterative Ranking for Knowledge Base Question Answering

RNG-KBQA: Generation Augmented Iterative Ranking for Knowledge Base Question Answering Authors: Xi Ye, Semih Yavuz, Kazuma Hashimoto, Yingbo Zhou and

Comments
  • WIP - MemformerEncoder

    WIP - MemformerEncoder

    I´m always trying all your awesome work on transformers. My problem is NER on very large texts, with few examples.

    Memformer is the first one so far to converge faster and wield better accuracy than RNN encoders as LSTM, SRU and IndRNN It is ridiculously better than everything else I tested, congratulations @lucidrains 🥳

    I need to use the transformer as a Encoder in my pipeline, to feed a CRF layer. So I modified the code to accept an already embedded input, and to only do the Encode step.

    TODO:

    • [ ] Support Mask
    • [ ] Re-utilize code with Memformer class

    Is this within the scope of the project?

    opened by bratao 10
  • ETA on complete examples

    ETA on complete examples

    @lucidrains As I asked about the feedback-transformer, I was also wondering about this memformer implementation as I would love to try it. Any eta on any complete examples here? They will be much appreciated. Thanks.

    And similarly, I would love to see a simple example for custom line-by-line TXT datasets as well.

    Thank you again :)

    opened by asigalov61 0
Owner
Phil Wang
Working with Attention. It's all we need
Phil Wang
Code for 2021 NeurIPS --- Towards Multi-Grained Explainability for Graph Neural Networks

ReFine: Multi-Grained Explainability for GNNs This is the official code for Towards Multi-Grained Explainability for Graph Neural Networks (NeurIPS 20

Shirley (Ying-Xin) Wu 47 Dec 16, 2022
REGTR: End-to-end Point Cloud Correspondences with Transformers

REGTR: End-to-end Point Cloud Correspondences with Transformers This repository contains the source code for REGTR. REGTR utilizes multiple transforme

Zi Jian Yew 108 Dec 17, 2022
NU-Wave: A Diffusion Probabilistic Model for Neural Audio Upsampling

NU-Wave: A Diffusion Probabilistic Model for Neural Audio Upsampling For Official repo of NU-Wave: A Diffusion Probabilistic Model for Neural Audio Up

Rishikesh (ऋषिकेश) 38 Oct 11, 2022
This Deep Learning Model Predicts that from which disease you are suffering.

Deep-Learning-Project This Deep Learning Model Predicts that from which disease you are suffering. This Project Covers the Topics of Deep Learning Int

Jai Viral Doshi 0 Jan 20, 2022
【steal piano】GitHub偷情分析工具!

【steal piano】GitHub偷情分析工具! 你是否有这样的困扰,有一天你的仓库被很多人加了star,但是你却不知道这些人都是从哪来的? 别担心,GitHub偷情分析工具帮你轻松解决问题! 原理 GitHub偷情分析工具透过分析star的时间以及他们之间的follow关系,可以推测出每个st

黄巍 442 Dec 21, 2022
NFT-Price-Prediction-CNN - Using visual feature extraction, prices of NFTs are predicted via CNN (Alexnet and Resnet) architectures.

NFT-Price-Prediction-CNN - Using visual feature extraction, prices of NFTs are predicted via CNN (Alexnet and Resnet) architectures.

5 Nov 03, 2022
Label Mask for Multi-label Classification

LM-MLC 一种基于完型填空的多标签分类算法 1 前言 本文主要介绍本人在全球人工智能技术创新大赛【赛道一】设计的一种基于完型填空(模板)的多标签分类算法:LM-MLC,该算法拟合能力很强能感知标签关联性,在多个数据集上测试表明该算法与主流算法无显著性差异,在该比赛数据集上的dev效果很好,但是由

52 Nov 20, 2022
A simple API wrapper for Discord interactions.

Your ultimate Discord interactions library for discord.py. About | Installation | Examples | Discord | PyPI About What is discord-py-interactions? dis

james 641 Jan 03, 2023
Unrolled Generative Adversarial Networks

Unrolled Generative Adversarial Networks Luke Metz, Ben Poole, David Pfau, Jascha Sohl-Dickstein arxiv:1611.02163 This repo contains an example notebo

Ben Poole 292 Dec 06, 2022
Implementation of light baking system for ray tracing based on Activision's UberBake

Vulkan Light Bakary MSU Graphics Group Student's Diploma Project Treefonov Andrey [GitHub] [LinkedIn] Project Goal The goal of the project is to imple

Andrey Treefonov 7 Dec 27, 2022
Replication package for the manuscript "Using Personality Detection Tools for Software Engineering Research: How Far Can We Go?" submitted to TOSEM

tosem2021-personality-rep-package Replication package for the manuscript "Using Personality Detection Tools for Software Engineering Research: How Far

Collaborative Development Group 1 Dec 13, 2021
Mixed Neural Likelihood Estimation for models of decision-making

Mixed neural likelihood estimation for models of decision-making Mixed neural likelihood estimation (MNLE) enables Bayesian parameter inference for mo

mackelab 9 Dec 22, 2022
Husein pet projects in here!

project-suka-suka Husein pet projects in here! List of projects mysejahtera-density. Generate resolution points using meshgrid and request each points

HUSEIN ZOLKEPLI 47 Dec 09, 2022
Face Detection & Age Gender & Expression & Recognition

Face Detection & Age Gender & Expression & Recognition

Sajjad Ayobi 188 Dec 28, 2022
Adabelief-Optimizer - Repository for NeurIPS 2020 Spotlight "AdaBelief Optimizer: Adapting stepsizes by the belief in observed gradients"

AdaBelief Optimizer NeurIPS 2020 Spotlight, trains fast as Adam, generalizes well as SGD, and is stable to train GANs. Release of package We have rele

Juntang Zhuang 998 Dec 29, 2022
SCAN: Learning to Classify Images without Labels, incl. SimCLR. [ECCV 2020]

Learning to Classify Images without Labels This repo contains the Pytorch implementation of our paper: SCAN: Learning to Classify Images without Label

Wouter Van Gansbeke 1.1k Dec 30, 2022
A clean and extensible PyTorch implementation of Masked Autoencoders Are Scalable Vision Learners

A clean and extensible PyTorch implementation of Masked Autoencoders Are Scalable Vision Learners A PyTorch re-implementation of Mask Autoencoder trai

Tianyu Hua 23 Dec 13, 2022
Official PyTorch implementation of the paper: Improving Graph Neural Network Expressivity via Subgraph Isomorphism Counting.

Improving Graph Neural Network Expressivity via Subgraph Isomorphism Counting Official PyTorch implementation of the paper: Improving Graph Neural Net

Giorgos Bouritsas 58 Dec 31, 2022
Using Machine Learning to Test Causal Hypotheses in Conjoint Analysis

Readme File for "Using Machine Learning to Test Causal Hypotheses in Conjoint Analysis" by Ham, Imai, and Janson. (2022) All scripts were written and

0 Jan 27, 2022
MIMIC Code Repository: Code shared by the research community for the MIMIC-III database

MIMIC Code Repository The MIMIC Code Repository is intended to be a central hub for sharing, refining, and reusing code used for analysis of the MIMIC

MIT Laboratory for Computational Physiology 1.8k Dec 26, 2022