An implementation of Performer, a linear attention-based transformer, in Pytorch


Performer - Pytorch

PyPI version

An implementation of Performer, a linear attention-based transformer variant with a Fast Attention Via positive Orthogonal Random features approach (FAVOR+).


$ pip install performer-pytorch


Performer Language Model

# Example: build a Performer language model and run one forward pass on random token ids.
import torch
from torch import nn
from performer_pytorch import PerformerLM

model = PerformerLM(
    num_tokens = 20000,
    max_seq_len = 2048,             # max sequence length
    dim = 512,                      # dimension
    depth = 12,                     # layers
    heads = 8,                      # heads
    causal = False,                 # auto-regressive or not
    nb_features = 256,              # number of random features, if not set, will default to (d * log(d)), where d is the dimension of each head
    feature_redraw_interval = 1000, # how frequently to redraw the projection matrix, the more frequent, the slower the training
    generalized_attention = False,  # defaults to softmax approximation, but can be set to True for generalized attention
    kernel_fn = nn.ReLU(),          # the kernel function to be used, if generalized attention is turned on, defaults to ReLU
    reversible = True,              # reversible layers, from Reformer paper
    ff_chunks = 10,                 # chunk feedforward layer, from Reformer paper
    use_scalenorm = False,          # use scale norm, from 'Transformers without Tears' paper
    use_rezero = False,             # use rezero, from 'Rezero is all you need' paper
    tie_embedding = False,          # multiply final embeddings with token weights for logits, like gpt decoder
    ff_glu = True,                  # use GLU variant for feedforward
    emb_dropout = 0.1,              # embedding dropout
    ff_dropout = 0.1,               # feedforward dropout
    attn_dropout = 0.1,             # post-attn dropout
    local_attn_heads = 4,           # 4 heads are local attention, 4 others are global performers
    local_window_size = 256,        # window size of local attention
    rotary_position_emb = True      # use rotary positional embedding, which endows linear attention with relative positional encoding with no learned parameters. should always be turned on unless you want to go back to the old absolute positional encoding
)

# a batch of one sequence of random token ids, with a mask that marks every position as valid
x = torch.randint(0, 20000, (1, 2048))
mask = torch.ones_like(x).bool()

model(x, mask = mask) # (1, 2048, 20000)

Plain Performer, if you are working with, say, images or other modalities

# Example: plain Performer operating directly on already-embedded inputs.
import torch
from performer_pytorch import Performer

model = Performer(
    dim = 512,
    depth = 1,
    heads = 8,
    causal = True
)

# random input whose last dimension matches `dim` above
x = torch.randn(1, 2048, 512)
model(x) # (1, 2048, 512)

Encoder / Decoder - Made possible by Thomas Melistas

# Example: encoder / decoder Performer, showing one training step and one generation call.
import torch
from performer_pytorch import PerformerEncDec

SRC_SEQ_LEN = 4096
TGT_SEQ_LEN = 4096
GENERATE_LEN = 512  # maximum number of tokens to sample in the generation example below

enc_dec = PerformerEncDec(
    dim = 512,
    tie_token_embed = True,
    enc_num_tokens = 20000,
    enc_depth = 6,
    enc_heads = 8,
    enc_max_seq_len = SRC_SEQ_LEN,
    dec_num_tokens = 20000,
    dec_depth = 6,
    dec_heads = 8,
    dec_max_seq_len = TGT_SEQ_LEN,
)

src = torch.randint(0, 20000, (1, SRC_SEQ_LEN))
tgt = torch.randint(0, 20000, (1, TGT_SEQ_LEN))
src_mask = torch.ones_like(src).bool()
# the decoder mask is shaped after the target, so the example stays valid when the two lengths differ
tgt_mask = torch.ones_like(tgt).bool()

# train
loss = enc_dec(src, tgt, enc_mask = src_mask, dec_mask = tgt_mask)

# generate
generate_in = torch.randint(0, 20000, (1, SRC_SEQ_LEN))
generate_out_prime = torch.tensor([[0]]) # prime with <bos> token
samples = enc_dec.generate(generate_in, generate_out_prime, seq_len = GENERATE_LEN, eos_token = 1) # assume 1 is id of stop token
print(samples.shape) # (1, <= GENERATE_LEN) decode the tokens

Standalone self-attention layer with linear complexity with respect to sequence length, for replacing trained full-attention transformer self-attention layers.

# Example: standalone linear-complexity self-attention layer (this snippet requires a CUDA device).
import torch
from performer_pytorch import SelfAttention

# the module is moved to the GPU so that its parameters live on the same device as the input below
attn = SelfAttention(
    dim = 512,
    heads = 8,
    causal = False,
).cuda()

x = torch.randn(1, 1024, 512).cuda()
attn(x) # (1, 1024, 512)

To minimize model surgery, you could also simply rewrite the code, so that the attention step is done by the FastAttention module, as follows.

# Example: use FastAttention as a drop-in for the attention step of an existing model.
import torch
from performer_pytorch import FastAttention

# queries / keys / values with heads already split and transposed to first dimension
# 8 heads, dimension of head is 64, sequence length of 512
q = torch.randn(1, 8, 512, 64)
k = torch.randn(1, 8, 512, 64)
v = torch.randn(1, 8, 512, 64)

attn_fn = FastAttention(
    dim_heads = 64,
    nb_features = 256,
    causal = False
)

out = attn_fn(q, k, v) # (1, 8, 512, 64)
# now merge heads and combine outputs with Wo


At the end of training, if you wish to fix the projection matrices so that the model produces deterministic output, you can invoke the following

model.fix_projection_matrices_()

Now your model will have fixed projection matrices across all layers


