Implementation of Memorizing Transformers (ICLR 2022), attention net augmented with indexing and retrieval of memories using approximate nearest neighbors, in Pytorch

Overview

Memorizing Transformers - Pytorch

This repository deviates slightly from the paper: instead of the sigmoid gate setup, it uses hybrid attention, in which the local and distant (retrieved memory) attention logits are attended over together. It also uses cosine similarity attention (with a learned temperature) for the KNN attention layer.
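
To make the hybrid attention idea concrete, here is a minimal, dependency-free sketch for a single query. Cosine-similarity logits are computed against both the local keys and the retrieved memory keys, scaled by a temperature, and normalized with one softmax over the concatenation. The function names and the fixed `temperature` value are illustrative assumptions, not the repository's code (where the temperature is learned and everything runs on batched tensors).

```python
import math

def l2_normalize(v):
    # Scale a vector to unit length (a zero vector is left unchanged).
    norm = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / norm for x in v]

def cosine_logits(query, keys, temperature):
    # Cosine similarity between the query and each key, scaled by a temperature.
    q = l2_normalize(query)
    return [temperature * sum(a * b for a, b in zip(q, l2_normalize(k))) for k in keys]

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def hybrid_attention(query, local_kv, memory_kv, temperature=10.0):
    """Attend over retrieved and local (key, value) pairs with a single softmax."""
    pairs = list(memory_kv) + list(local_kv)   # memories first, then local context
    keys = [k for k, _ in pairs]
    values = [v for _, v in pairs]
    weights = softmax(cosine_logits(query, keys, temperature))
    dim = len(values[0])
    out = [sum(w * v[i] for w, v in zip(weights, values)) for i in range(dim)]
    return out, weights

local_kv = [([1.0, 0.0], [1.0, 1.0]), ([0.0, 1.0], [2.0, 2.0])]
memory_kv = [([1.0, 1.0], [3.0, 3.0])]
out, weights = hybrid_attention([1.0, 0.2], local_kv, memory_kv)
```

Because the weights come from one softmax, local and retrieved entries compete directly for attention mass, so no separate gate is needed to mix the two sources.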

Install

$ pip install memorizing-transformers-pytorch

Usage

import torch
from memorizing_transformers_pytorch import MemorizingTransformer

model = MemorizingTransformer(
    num_tokens = 20000,                 # number of tokens
    dim = 512,                          # dimension
    dim_head = 64,                      # dimension per attention head
    depth = 8,                          # number of layers
    memorizing_layers = (4, 5),         # which layers to have ANN memories
    max_knn_memories = 64000,           # maximum ANN memories to keep (once it hits this capacity, it will be reset for now, due to limitations in faiss' ability to remove entries)
    num_retrieved_memories = 32,        # number of ANN memories to retrieve
    clear_memories_on_sos_token_id = 1, # clear passed in ANN memories automatically for batch indices which contain this specified SOS token id - otherwise, you can also manually iterate through the ANN memories and clear the indices before the next iteration
)

data = torch.randint(0, 20000, (2, 1024)) # mock data

knn_memories = model.create_knn_memories(batch_size = 2) # create collection of KNN memories with the correct batch size (2 in example)

logits = model(data, knn_memories = knn_memories) # (2, 1024, 20000)
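
As a rough illustration of what `clear_memories_on_sos_token_id` automates, the sketch below finds the batch rows that contain the SOS token; those are the indices whose memories would need clearing before the next iteration. The helper name `batch_indices_with_sos` is hypothetical, and the token ids are plain Python lists rather than tensors.

```python
def batch_indices_with_sos(batch_token_ids, sos_token_id):
    """Return the batch indices whose sequence contains the SOS token."""
    return [i for i, seq in enumerate(batch_token_ids) if sos_token_id in seq]

batch = [
    [5, 9, 1, 7],   # contains SOS (id 1): a new document starts in this row
    [4, 4, 8, 2],   # no SOS: this row continues its current document
]
to_clear = batch_indices_with_sos(batch, sos_token_id=1)
```

Only the rows that start a new document lose their memories; the other rows keep accumulating context across iterations.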

You can make the KNN memories read-only by passing add_knn_memory = False to the forward call, for example:

logits = model(data, knn_memories = knn_memories, add_knn_memory = False) # knn memories will not be updated

With Transformer-XL memories (only the memories that will be discarded will be added to the KNN memory)

import torch
from memorizing_transformers_pytorch import MemorizingTransformer

model = MemorizingTransformer(
    num_tokens = 20000,
    dim = 512,
    depth = 8,
    memorizing_layers = (4, 5),
    max_knn_memories = 64000,
    num_retrieved_memories = 32,
    clear_memories_on_sos_token_id = 1,
    xl_memory_layers = (2, 3, 4, 5),      # xl memory layers - (https://arxiv.org/abs/2007.03356 shows you do not need XL memory on all layers, just the latter ones) - if a KNNAttention layer ends up using XL memories, only the XL memories that will be discarded will be added to long term memory
    xl_max_memories = 512,                # number of xl memories to keep
    shift_knn_memories_down = 1,          # let a layer look at the KNN memories this number of layers above
    shift_xl_memories_down = 1,           # let a layer look at the XL memories this number of layers above, shown to enhance receptive field in ernie-doc paper
)

data = torch.randint(0, 20000, (2, 1024)) # mock data

xl_memories = None

with model.knn_memories_context(batch_size = 2) as knn_memories:
    logits1, xl_memories = model(data, knn_memories = knn_memories, xl_memories = xl_memories)
    logits2, xl_memories = model(data, knn_memories = knn_memories, xl_memories = xl_memories)
    logits3, xl_memories = model(data, knn_memories = knn_memories, xl_memories = xl_memories)

    # ... and so on
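
The note above, that only the memories about to be discarded from the XL cache are added to the KNN memory, can be pictured as a rolling window. The following is a sketch of that bookkeeping under stated assumptions: it uses plain lists instead of tensors, and `roll_xl_memories` is a hypothetical helper, not part of the package.

```python
def roll_xl_memories(xl_memories, new_entries, xl_max_memories):
    """Append new entries to the XL cache and split off what no longer fits.

    Returns (kept, discarded): `kept` holds the most recent `xl_max_memories`
    entries, `discarded` is what would be handed over to long-term memory.
    """
    combined = list(xl_memories) + list(new_entries)
    overflow = max(len(combined) - xl_max_memories, 0)
    return combined[overflow:], combined[:overflow]

xl, long_term = [], []
for segment in (["a", "b", "c"], ["d", "e", "f"], ["g", "h", "i"]):
    xl, discarded = roll_xl_memories(xl, segment, xl_max_memories=4)
    long_term.extend(discarded)   # only the overflow reaches long-term memory
```

With this split, an entry lives in exactly one place at a time: first in the short-term XL cache, then in the long-term store once it has been pushed out.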

KNN Memory

This repository contains a wrapper around Faiss that can automatically store and retrieve key / values

import torch
from memorizing_transformers_pytorch import KNNMemory

memory = KNNMemory(
    dim = 64,                   # dimension of key / values
    max_memories = 64000,       # maximum number of memories to keep (will throw out the oldest memories for now if it overfills)
    num_indices = 2             # this should be equivalent to batch dimension, as each batch keeps track of its own memories, expiring when it sees a new document
)

memory.add(torch.randn(2, 512, 2, 64))  # (batch, seq, key | value, feature dim)
memory.add(torch.randn(2, 512, 2, 64))

memory.clear([0]) # clear batch 0, if it saw an <sos>

memory.add(torch.randn(2, 512, 2, 64))
memory.add(torch.randn(2, 512, 2, 64))

key_values, mask = memory.search(torch.randn(2, 512, 64), topk = 32)
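
For intuition about this interface, here is a brute-force, dependency-free sketch of a per-batch key / value memory with a capacity limit, per-index clearing, and top-k retrieval that returns a validity mask. It uses exact dot-product search over Python lists rather than Faiss, and the class name `ToyKNNMemory`, its method signatures, and its eviction policy (drop the oldest entries) are assumptions made for illustration only.

```python
class ToyKNNMemory:
    def __init__(self, max_memories, num_indices):
        self.max_memories = max_memories
        # One independent list of (key, value) pairs per batch index.
        self.stores = [[] for _ in range(num_indices)]

    def add(self, batch_index, keys, values):
        store = self.stores[batch_index]
        store.extend(zip(keys, values))
        # Assumption: evict the oldest entries once capacity is exceeded.
        del store[:max(len(store) - self.max_memories, 0)]

    def clear(self, batch_indices):
        for i in batch_indices:
            self.stores[i].clear()

    def search(self, batch_index, query, topk):
        store = self.stores[batch_index]
        # Exact search: rank every stored key by its dot product with the query.
        scored = sorted(
            store,
            key=lambda kv: sum(q * k for q, k in zip(query, kv[0])),
            reverse=True,
        )[:topk]
        missing = topk - len(scored)
        # The mask marks which of the topk slots hold a real memory.
        mask = [True] * len(scored) + [False] * missing
        return scored + [None] * missing, mask

memory = ToyKNNMemory(max_memories=3, num_indices=2)
memory.add(0, keys=[[1.0, 0.0], [0.0, 1.0]], values=["v0", "v1"])
memory.add(1, keys=[[1.0, 1.0]], values=["w0"])
memory.clear([0])  # batch 0 saw an <sos>
memory.add(0, keys=[[0.5, 0.5]], values=["v2"])
hits, mask = memory.search(1, query=[1.0, 0.0], topk=2)
```

The mask matters because a freshly cleared index may hold fewer than `topk` entries, and the attention layer has to ignore the empty slots.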

Training

Enwik8 training

$ python train.py

Todo

  • switch to ivfhnsw and just remember all memories
  • enwik8 demo
  • validation for enwik8
  • solve gradient accumulation problem by offering some way to scope reads and writes to knn memories with another indices array
  • setup text generation with memories
  • figure out how to deal with memories efficiently once capacity has been hit
  • try to speed up reading and writing to knn memories collection with multiprocessing

Citations

@article{wu2022memorizing,
  title   = {Memorizing transformers},
  author  = {Wu, Yuhuai and Rabe, Markus N and Hutchins, DeLesley and Szegedy, Christian},
  journal = {arXiv preprint arXiv:2203.08913},
  year    = {2022}
}
@article{Shazeer2019FastTD,
  title   = {Fast Transformer Decoding: One Write-Head is All You Need},
  author  = {Noam M. Shazeer},
  journal = {ArXiv},
  year    = {2019},
  volume  = {abs/1911.02150}
}
@Article{AlphaFold2021,
  author  = {Jumper, John and Evans, Richard and Pritzel, Alexander and Green, Tim and Figurnov, Michael and Ronneberger, Olaf and Tunyasuvunakool, Kathryn and Bates, Russ and {\v{Z}}{\'\i}dek, Augustin and Potapenko, Anna and Bridgland, Alex and Meyer, Clemens and Kohl, Simon A A and Ballard, Andrew J and Cowie, Andrew and Romera-Paredes, Bernardino and Nikolov, Stanislav and Jain, Rishub and Adler, Jonas and Back, Trevor and Petersen, Stig and Reiman, David and Clancy, Ellen and Zielinski, Michal and Steinegger, Martin and Pacholska, Michalina and Berghammer, Tamas and Bodenstein, Sebastian and Silver, David and Vinyals, Oriol and Senior, Andrew W and Kavukcuoglu, Koray and Kohli, Pushmeet and Hassabis, Demis},
  journal = {Nature},
  title   = {Highly accurate protein structure prediction with {AlphaFold}},
  year    = {2021},
  doi     = {10.1038/s41586-021-03819-2},
  note    = {(Accelerated article preview)},
}
@inproceedings{Rae2020DoTN,
  title   = {Do Transformers Need Deep Long-Range Memory?},
  author  = {Jack W. Rae and Ali Razavi},
  booktitle = {ACL},
  year    = {2020}
}
@misc{ding2021erniedoc,
  title   = {ERNIE-Doc: A Retrospective Long-Document Modeling Transformer},
  author  = {Siyu Ding and Junyuan Shang and Shuohuan Wang and Yu Sun and Hao Tian and Hua Wu and Haifeng Wang},
  year    = {2021},
  eprint  = {2012.15688},
  archivePrefix = {arXiv},
  primaryClass = {cs.CL}
}
@misc{henry2020querykey,
    title   = {Query-Key Normalization for Transformers},
    author  = {Alex Henry and Prudhvi Raj Dachapally and Shubham Pawar and Yuxuan Chen},
    year    = {2020},
    eprint  = {2010.04245},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}

Memory is Attention through Time - Alex Graves

Comments
  • Arguments to reproduce the models from the original paper?

    Hi lucidrains,

    This looks like excellent work! I have gone through the original paper and your repo, and am now trying to reproduce the model from the paper as closely as possible. Of course, the modifications you made such as hybrid attention instead of sigmoid gate are fine.

    Specifically, I would like to be able to try some of the variations in Table 4.

    Suppose I'm interested in the 4th to last row with Context 512, Memory 8192, XL cache 512. Can you help me with the model arguments to do that? Here is my initial attempt, with reference to Section 4.2:

    model = MemorizingTransformer(
        num_tokens = 32000, # vocab 32k
        dim = 1024, 
        depth = 12,
        memorizing_layers = 9,
        max_knn_memories = 8192, # Memory column
        num_retrieved_memories = 32,
        clear_memories_on_sos_token_id = 1,
        xl_memory_layers = (6, 7, 8, 9),  # not sure about this?
        xl_max_memories = 512, # XL cache column
        shift_knn_memories_down = 1, 
        shift_xl_memories_down = 1,
        # which argument corresponds to Context column?
    ).cuda()
    
    

    A second question: what are the model arguments to reproduce the first row of Table 4, with neither memory nor XL cache? Thanks in advance.

    opened by manestay 1
  • KNNMemory add() does not appear to update self.knns

    Thanks for the nice implementation. I've adapted this code for my own use, so I don't have the whole stack that would reproduce this bug. However, you can check for yourself.

    The following code ought to update the KNN objects in the KNNMemory class:

    @delayed
    def knn_add(knn, key, db_offset):
        knn.add(key, ids = knn_insert_ids + db_offset)
    
    Parallel(n_jobs = self.n_jobs)(knn_add(*args) for args in zip(knns, keys, db_offsets))
    

    [link to that code here]

    However, even after repeated calls to add to the memory, calling KNNMemory.search() results in empty values. If you view self.knns at this point, self.is_trained remains False.

    When I modify the code as follows, this fixes the issue.

    @delayed
    def knn_add(knn, key, db_offset):
        knn.add(key, ids = knn_insert_ids + db_offset)
        return knn
    
    updated_knns = Parallel(n_jobs = self.n_jobs)(knn_add(*args) for args in zip(knns, keys, db_offsets))
    self.knns = updated_knns
    

    This will allow searches to return actual values.

    opened by vyaivo 0
  • FAISS hard reset

    Hello and thanks for this implementation!

    Do you know of any solutions to efficiently solve the "hard reset" problem in FAISS? I know that one could use IndexFlatL2 but that's not really efficient.

    Thank you!

    opened by itsdaniele 0
  •  index out of

    When I run train.py, I get the following error: "index out of range: Tried to access index 10218 out of table with 255 rows. at /pytorch/aten/src/TH/generic/THTensorEvenMoreMath.cpp:418"

    opened by chxiag 0
  • Support for Multi-GPU training?

    Thank you so much for the great implementation. I would like to ask whether your implementation of the Memorizing Transformer could support multi-GPU distributed training like the original paper. If the MemorizingTransformer model is replicated on each GPU, then every GPU holds its own memory and its own Faiss retrieval index. Each replica therefore has a different memory database and retrieval index, which differs from the original paper. I believe the replicas on different GPUs should share the same retrieval context. This problem confuses me a lot.

    Thank you so much for your time. Looking forward to your response!

    opened by Victorwz 0
  • Dimensionality of key and values for Attention

    I have two questions about the key and value calculation in Attention (and similarly for KNNAttention).

    The relevant line is: https://github.com/lucidrains/memorizing-transformers-pytorch/blob/83fa1479d6f7881dd977fbff55681e709e3b250e/memorizing_transformers_pytorch/memorizing_transformers_pytorch.py#L135

    1. Why is there only one Linear layer to_kv, instead of 2 linear layers to_k and to_v?
    2. Why is the last dimension dim_head*2? I get that *2 is for both k and v, but what about dim_head? I thought q, k, v should all have the same final dimension (i.e. inner_dim==dim_head*heads). My understanding is that this means either a) there is only 1 attention head, or b) k and v are shared across all heads. Is there a reason this is done, or am I misunderstanding?

    In your Attention class for Performer, q, k, v all have the same dimensions.

    Thanks in advance!

    opened by manestay 8
  • Maybe scale is wrong

    https://github.com/lucidrains/memorizing-transformers-pytorch/blob/83fa1479d6f7881dd977fbff55681e709e3b250e/memorizing_transformers_pytorch/memorizing_transformers_pytorch.py#L237

    Shouldn't this be (1-scale)?

    opened by denadai2 3
Releases(0.3.10)
Owner
Phil Wang
Working with Attention. It's all we need