Implementation of Memorizing Transformers (ICLR 2022), attention net augmented with indexing and retrieval of memories using approximate nearest neighbors, in Pytorch


Memorizing Transformers - Pytorch

This repository deviates slightly from the paper: instead of the sigmoid gate setup, it uses a hybrid attention across the local and distant attention logits. It also uses cosine similarity attention (with a learned temperature) for the KNN attention layer.
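As a rough illustration of the hybrid attention described above, the sketch below concatenates the local attention logits with the logits against retrieved memories and normalizes them together. It is a minimal sketch and not this repository's code: the function name `hybrid_attention`, the tensor shapes, the single shared key / value head, and the fixed `scale` argument (standing in for the learned temperature) are all assumptions.

```python
import torch
import torch.nn.functional as F

def hybrid_attention(q, k, v, mem_k, mem_v, scale = 10.0):
    # hypothetical helper, shapes are assumptions:
    # q:            (batch, heads, seq, dim_head)
    # k, v:         (batch, seq, dim_head)               local keys / values
    # mem_k, mem_v: (batch, heads, seq, topk, dim_head)  memories retrieved per query

    # cosine similarity attention: l2-normalize queries and keys, then scale
    q, k, mem_k = (F.normalize(t, dim = -1) for t in (q, k, mem_k))

    local_sim = torch.einsum('bhid,bjd->bhij', q, k) * scale
    mem_sim = torch.einsum('bhid,bhijd->bhij', q, mem_k) * scale

    # causal mask on the local logits only
    i, j = local_sim.shape[-2:]
    causal_mask = torch.ones(i, j, dtype = torch.bool).triu(1)
    local_sim = local_sim.masked_fill(causal_mask, -torch.finfo(local_sim.dtype).max)

    # "hybrid" step: one softmax over distant and local logits together
    num_mem = mem_sim.shape[-1]
    attn = torch.cat((mem_sim, local_sim), dim = -1).softmax(dim = -1)
    mem_attn, local_attn = attn[..., :num_mem], attn[..., num_mem:]

    # aggregate local and distant values with their share of the attention
    local_out = torch.einsum('bhij,bjd->bhid', local_attn, v)
    mem_out = torch.einsum('bhij,bhijd->bhid', mem_attn, mem_v)
    return local_out + mem_out
```

Because the two sets of logits share one normalization, the model can trade attention mass between local context and retrieved memories per query, without a separate gate.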


$ pip install memorizing-transformers-pytorch


import torch
from memorizing_transformers_pytorch import MemorizingTransformer

model = MemorizingTransformer(
    num_tokens = 20000,                 # number of tokens
    dim = 512,                          # dimension
    dim_head = 64,                      # dimension per attention head
    depth = 8,                          # number of layers
    memorizing_layers = (4, 5),         # which layers to have ANN memories
    max_knn_memories = 64000,           # maximum ANN memories to keep (once it hits this capacity, it will be reset for now, due to limitations in faiss' ability to remove entries)
    num_retrieved_memories = 32,        # number of ANN memories to retrieve
    clear_memories_on_sos_token_id = 1, # clear passed in ANN memories automatically for batch indices which contain this specified SOS token id - otherwise, you can also manually iterate through the ANN memories and clear the indices before the next iteration
)

data = torch.randint(0, 20000, (2, 1024)) # mock data

knn_memories = model.create_knn_memories(batch_size = 2) # create collection of KNN memories with the correct batch size (2 in example)

logits = model(data, knn_memories = knn_memories) # (2, 1024, 20000)
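The comment on clear_memories_on_sos_token_id mentions clearing memories manually as an alternative. A minimal sketch of the first half of that, finding which batch indices contain a given token, might look like the following. The helper name `batch_indices_with_token` is hypothetical and not part of the package.

```python
import torch

def batch_indices_with_token(data, token_id):
    # data: (batch, seq) tensor of token ids
    # returns the batch indices whose sequence contains token_id at least once
    has_token = (data == token_id).any(dim = -1)
    return has_token.nonzero(as_tuple = True)[0].tolist()

data = torch.tensor([
    [5, 1, 7],   # contains the SOS id 1
    [2, 3, 4],   # does not
    [1, 1, 9],   # contains it twice
])

indices = batch_indices_with_token(data, token_id = 1)
```

A list like this has the same shape as the argument to the memory.clear([0]) call shown in the KNN Memory section.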

You can make the KNN memories read-only by setting add_knn_memory to False on forward.


logits = model(data, knn_memories = knn_memories, add_knn_memory = False) # knn memories will not be updated

With Transformer-XL memories (only the memories that will be discarded will be added to the KNN memory)

import torch
from memorizing_transformers_pytorch import MemorizingTransformer

model = MemorizingTransformer(
    num_tokens = 20000,
    dim = 512,
    depth = 8,
    memorizing_layers = (4, 5),
    max_knn_memories = 64000,
    num_retrieved_memories = 32,
    clear_memories_on_sos_token_id = 1,
    xl_memory_layers = (2, 3, 4, 5),      # xl memory layers - ( shows you do not need XL memory on all layers, just the latter ones) - if a KNNAttention layer ends up using XL memories, only the XL memories that will be discarded will be added to long term memory
    xl_max_memories = 512,                # number of xl memories to keep
    shift_knn_memories_down = 1,          # let a layer look at the KNN memories this number of layers above
    shift_xl_memories_down = 1,           # let a layer look at the XL memories this number of layers above, shown to enhance receptive field in ernie-doc paper
)

data = torch.randint(0, 20000, (2, 1024)) # mock data

xl_memories = None

with model.knn_memories_context(batch_size = 2) as knn_memories:
    logits1, xl_memories = model(data, knn_memories = knn_memories, xl_memories = xl_memories)
    logits2, xl_memories = model(data, knn_memories = knn_memories, xl_memories = xl_memories)
    logits3, xl_memories = model(data, knn_memories = knn_memories, xl_memories = xl_memories)

    # ... and so on
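The note above says that only the XL memories about to be discarded are added to the KNN memory. A minimal sketch of that bookkeeping follows, assuming key / values are stored as (batch, seq, 2, dim) tensors and that xl_max_memories is greater than zero. The helper `split_xl_memories` is hypothetical and only illustrates the idea; it is not the package's internal code.

```python
import torch

def split_xl_memories(xl_memories, new_kv, xl_max_memories):
    # xl_memories: (batch, n_old, 2, dim) or None on the first segment
    # new_kv:      (batch, seq, 2, dim) key / values from the current segment
    # assumes xl_max_memories > 0
    kv = new_kv if xl_memories is None else torch.cat((xl_memories, new_kv), dim = 1)

    # the newest entries stay in the short-term XL cache
    kept = kv[:, -xl_max_memories:]

    # the oldest entries no longer fit: these are the ones that would go to long-term memory
    overflow = kv[:, :-xl_max_memories]
    return kept, overflow

kept, overflow = split_xl_memories(None, torch.randn(2, 1024, 2, 64), xl_max_memories = 512)
```

With a 1024-token segment and a 512-entry cache, the first 512 key / values overflow and the last 512 are kept for the next segment.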

KNN Memory

This repository contains a wrapper around Faiss that can automatically store and retrieve key / values

import torch
from memorizing_transformers_pytorch import KNNMemory

memory = KNNMemory(
    dim = 64,                   # dimension of key / values
    max_memories = 64000,       # maximum number of memories to keep (will throw out the oldest memories for now if it overfills)
    num_indices = 2             # this should be equivalent to batch dimension, as each batch keeps track of its own memories, expiring when it sees a new document
)

memory.add(torch.randn(2, 512, 2, 64))  # (batch, seq, key | value, feature dim)
memory.add(torch.randn(2, 512, 2, 64))

memory.clear([0]) # clear batch 0, if it saw an <sos>

memory.add(torch.randn(2, 512, 2, 64))
memory.add(torch.randn(2, 512, 2, 64))

key_values, mask = memory.search(torch.randn(2, 512, 64), topk = 32)
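To make the store-and-retrieve idea concrete without a Faiss dependency, here is a minimal stand-in that uses exact inner-product search in plain PyTorch. The class `BruteForceKVMemory` is hypothetical and not part of the package. It handles a single sequence rather than a batch, and it performs exact rather than approximate search, so treat it as a sketch of the interface only.

```python
import torch

class BruteForceKVMemory:
    # hypothetical stand-in: exact search over stored keys, no Faiss involved

    def __init__(self, dim, max_memories):
        self.max_memories = max_memories
        self.kv = torch.empty(0, 2, dim)   # (n, key | value, dim)

    def add(self, kv):
        # kv: (n, 2, dim); keep only the newest max_memories entries
        self.kv = torch.cat((self.kv, kv), dim = 0)[-self.max_memories:]

    def clear(self):
        self.kv = self.kv[:0]

    def search(self, queries, topk):
        # queries: (n_q, dim)
        # returns key / values of shape (n_q, topk, 2, dim) and a boolean mask (n_q, topk)
        n_q, dim = queries.shape
        out = torch.zeros(n_q, topk, 2, dim)
        mask = torch.zeros(n_q, topk, dtype = torch.bool)

        k = min(topk, self.kv.shape[0])
        if k == 0:
            return out, mask              # nothing stored, everything is masked out

        keys = self.kv[:, 0]              # (n, dim)
        sim = queries @ keys.t()          # inner-product similarity, (n_q, n)
        idx = sim.topk(k, dim = -1).indices

        out[:, :k] = self.kv[idx]         # gather matching key / value pairs
        mask[:, :k] = True                # mark which slots hold real memories
        return out, mask
```

The mask matters because fewer than topk memories may exist, for example right after a clear, and attention should ignore the empty slots.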


Enwik8 training

$ python


  • switch to ivfhnsw and just remember all memories
  • enwik8 demo
  • validation for enwik8
  • solve gradient accumulation problem by offering some way to scope reads and writes to knn memories with another indices array
  • setup text generation with memories
  • figure out how to deal with memories efficiently once capacity has been hit
  • try to speed up reading and writing to knn memories collection with multiprocessing


Memory is Attention through Time - Alex Graves

  • Arguments to reproduce the models from the original paper?

    Hi lucidrains,

    This looks like excellent work! I have gone through the original paper and your repo, and am now trying to reproduce the model from the paper as closely as possible. Of course, the modifications you made such as hybrid attention instead of sigmoid gate are fine.

    Specifically, I would like to be able to try some of the variations in Table 4.

    Suppose I'm interested in the 4th to last row, with Context 512, Memory 8192, and XL cache 512. Can you help me with the model arguments to do that? Here is my initial attempt, with reference to Section 4.2:

    model = MemorizingTransformer(
        num_tokens = 32000, # vocab 32k
        dim = 1024, 
        depth = 12,
        memorizing_layers = 9,
        max_knn_memories = 8192, # Memory column
        num_retrieved_memories = 32,
        clear_memories_on_sos_token_id = 1,
        xl_memory_layers = (6, 7, 8, 9),  # not sure about this?
        xl_max_memories = 512, # XL cache column
        shift_knn_memories_down = 1, 
        shift_xl_memories_down = 1,
        # which argument corresponds to Context column?
    )

    A second question: what are the model arguments to reproduce the first row of Table 4, with neither memory nor XL cache? Thanks in advance.

    opened by manestay 1
  • KNNMemory add() does not appear to update self.knns

    Thanks for the nice implementation. I've adapted this code for my own use, so I don't have the whole stack that would reproduce this bug. However, you can check for yourself.

    The following code ought to update the KNN objects in the KNNMemory class:

    def knn_add(knn, key, db_offset):
        knn.add(key, ids = knn_insert_ids + db_offset)
    Parallel(n_jobs = self.n_jobs)(knn_add(*args) for args in zip(knns, keys, db_offsets))

    [link to that code here]

    However, even after repeated calls to add to the memory, searching it returns empty values. If you view self.knns at this point, self.is_trained remains False.

    When I modify the code as follows, this fixes the issue.

    def knn_add(knn, key, db_offset):
        knn.add(key, ids = knn_insert_ids + db_offset)
        return knn
    updated_knns = Parallel(n_jobs = self.n_jobs)(knn_add(*args) for args in zip(knns, keys, db_offsets))
    self.knns = updated_knns

    This will allow searches to return actual values.

    opened by vyaivo 0
  • FAISS hard reset

    Hello and thanks for this implementation!

    Do you know of any solutions to efficiently solve the "hard reset" problem in FAISS? I know that one could use IndexFlatL2 but that's not really efficient.

    Thank you!

    opened by itsdaniele 0
  •  index out of

    When I run it, an error like this occurs: "index out of range: Tried to access index 10218 out of table with 255 rows. at /pytorch/aten/src/TH/generic/THTensorEvenMoreMath.cpp:418"

    opened by chxiag 0
  • Support for Multi-GPU training?

    Thank you so much for the great implementation. I would like to ask whether your implementation of Memorizing Transformers supports multi-card distributed training like the original paper. If the MemorizingTransformer model is replicated on each GPU, then every GPU holds its own memory with its own Faiss retrieval index. Each replica therefore has a different memory database and retrieval index, which differs from the original paper. I believe the replicas on different GPUs should share the same retrieval context. This problem confuses me a lot.

    Thank you so much for your time. Looking forward to your response!

    opened by Victorwz 0
  • Dimensionality of key and values for Attention

    I have two questions about the key and value calculation in Attention (and similarly for KNNAttention).

    The relevant line is:

    1. Why is there only one Linear layer to_kv, instead of 2 linear layers to_k and to_v?
    2. Why is the last dimension dim_head*2? I get that *2 is for both k and v, but what about dim_head? I thought q, k, v should all have the same final dimension (i.e. inner_dim==dim_head*heads). My understanding is that this means either a) there is only 1 attention head, or b) k and v are shared across all heads. Is there a reason this is done, or am I misunderstanding?

    In your Attention class for Performer, q, k, v all have the same dimensions.

    Thanks in advance!

    opened by manestay 8
  • Maybe scale is wrong

    Shouldn't this be (1-scale)?

    opened by denadai2 3
