An implementation of Performer, a linear attention-based transformer, in Pytorch

Overview

Performer - Pytorch

An implementation of Performer, a linear attention-based transformer variant with a Fast Attention Via positive Orthogonal Random features approach (FAVOR+).
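
To give a rough intuition for FAVOR+, here is a small, hypothetical sketch of the positive random-feature map described in the paper (omitting the numerical stabilization and query/key scaling that the actual library applies); it only illustrates how exp(q · k) can be approximated by an inner product of positive features.

import torch

def softmax_kernel_features(x, projection):
    # x: (..., dim), projection: (nb_features, dim) of random Gaussian rows
    # phi(x) = exp(w^T x - ||x||^2 / 2) / sqrt(m), so that E[phi(q) . phi(k)] ~= exp(q . k)
    m = projection.shape[0]
    x_proj = x @ projection.t()
    sq_norm = (x ** 2).sum(dim = -1, keepdim = True) / 2
    return torch.exp(x_proj - sq_norm) / (m ** 0.5)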

Install

$ pip install performer-pytorch

Usage

Performer Language Model

import torch
from torch import nn
from performer_pytorch import PerformerLM

model = PerformerLM(
    num_tokens = 20000,
    max_seq_len = 2048,             # max sequence length
    dim = 512,                      # dimension
    depth = 12,                     # layers
    heads = 8,                      # heads
    causal = False,                 # auto-regressive or not
    nb_features = 256,              # number of random features, if not set, will default to (d * log(d)), where d is the dimension of each head
    feature_redraw_interval = 1000, # how frequently to redraw the projection matrix, the more frequent, the slower the training
    generalized_attention = False,  # defaults to softmax approximation, but can be set to True for generalized attention
    kernel_fn = nn.ReLU(),          # the kernel function to be used if generalized attention is turned on, defaults to ReLU
    reversible = True,              # reversible layers, from Reformer paper
    ff_chunks = 10,                 # chunk feedforward layer, from Reformer paper
    use_scalenorm = False,          # use scale norm, from 'Transformers without Tears' paper
    use_rezero = False,             # use rezero, from 'Rezero is all you need' paper
    tie_embedding = False,          # multiply final embeddings with token weights for logits, like gpt decoder
    ff_glu = True,                  # use GLU variant for feedforward
    emb_dropout = 0.1,              # embedding dropout
    ff_dropout = 0.1,               # feedforward dropout
    attn_dropout = 0.1,             # post-attn dropout
    local_attn_heads = 4,           # 4 heads are local attention, 4 others are global performers
    local_window_size = 256,        # window size of local attention
    rotary_position_emb = True      # use rotary positional embedding, which endows linear attention with relative positional encoding with no learned parameters. should always be turned on unless you want to go back to the old absolute positional encoding
)

x = torch.randint(0, 20000, (1, 2048))
mask = torch.ones_like(x).bool()

model(x, mask = mask) # (1, 2048, 20000)
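
As a quick illustration of how the logits above can be used, here is a minimal, hypothetical training step for an autoregressive setup (i.e. causal = True, unlike the configuration shown above); it is not part of the library itself.

import torch
import torch.nn.functional as F

optim = torch.optim.Adam(model.parameters(), lr = 1e-4)

tokens = torch.randint(0, 20000, (1, 2049))
inp, target = tokens[:, :-1], tokens[:, 1:]             # predict the next token

logits = model(inp)                                     # (1, 2048, 20000)
loss = F.cross_entropy(logits.transpose(1, 2), target)
loss.backward()
optim.step()
optim.zero_grad()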

Plain Performer, if you are working with, say, images or other modalities

import torch
from performer_pytorch import Performer

model = Performer(
    dim = 512,
    depth = 1,
    heads = 8,
    causal = True
)

x = torch.randn(1, 2048, 512)
model(x) # (1, 2048, 512)

Encoder / Decoder - Made possible by Thomas Melistas

import torch
from performer_pytorch import PerformerEncDec

SRC_SEQ_LEN = 4096
TGT_SEQ_LEN = 4096
GENERATE_LEN = 512

enc_dec = PerformerEncDec(
    dim = 512,
    tie_token_embed = True,
    enc_num_tokens = 20000,
    enc_depth = 6,
    enc_heads = 8,
    enc_max_seq_len = SRC_SEQ_LEN,
    dec_num_tokens = 20000,
    dec_depth = 6,
    dec_heads = 8,
    dec_max_seq_len = TGT_SEQ_LEN,
)

src = torch.randint(0, 20000, (1, SRC_SEQ_LEN))
tgt = torch.randint(0, 20000, (1, TGT_SEQ_LEN))
src_mask = torch.ones_like(src).bool()
tgt_mask = torch.ones_like(tgt).bool()

# train
enc_dec.train()
loss = enc_dec(src, tgt, enc_mask = src_mask, dec_mask = tgt_mask)
loss.backward()

# generate
generate_in = torch.randint(0, 20000, (1, SRC_SEQ_LEN)).long()
generate_out_prime = torch.tensor([[0.]]).long() # prime with <bos> token
samples = enc_dec.generate(generate_in, generate_out_prime, seq_len = GENERATE_LEN, eos_token = 1) # assume 1 is id of stop token
print(samples.shape) # (1, <= GENERATE_LEN) decode the tokens

Standalone self-attention layer with linear complexity with respect to sequence length, for replacing trained full-attention transformer self-attention layers.

import torch
from performer_pytorch import SelfAttention

attn = SelfAttention(
    dim = 512,
    heads = 8,
    causal = False,
).cuda()

x = torch.randn(1, 1024, 512).cuda()
attn(x) # (1, 1024, 512)
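
As a sketch of the layer-swapping idea, here is a toy example with a hypothetical full-attention block standing in for a layer of your trained model; the names FullAttention and Block are placeholders, and only SelfAttention comes from this library. You would fine-tune afterwards so the new attention layers adapt to the trained weights.

import torch
from torch import nn
from performer_pytorch import SelfAttention

# hypothetical stand-in for a trained full-attention layer with an attn(x) interface
class FullAttention(nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
        self.mha = nn.MultiheadAttention(dim, heads, batch_first = True)

    def forward(self, x):
        return self.mha(x, x, x)[0]

class Block(nn.Module):
    def __init__(self, dim = 512, heads = 8):
        super().__init__()
        self.attn = FullAttention(dim, heads)
        self.ff = nn.Linear(dim, dim)

    def forward(self, x):
        x = x + self.attn(x)
        return x + self.ff(x)

model = nn.Sequential(*[Block() for _ in range(2)])

# swap each full-attention layer for the linear-complexity SelfAttention
for block in model:
    block.attn = SelfAttention(dim = 512, heads = 8, causal = False)

x = torch.randn(1, 1024, 512)
model(x) # (1, 1024, 512)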

To minimize model surgery, you could also simply rewrite the code so that the attention step is done by the FastAttention module, as follows.

import torch
from performer_pytorch import FastAttention

# queries / keys / values with heads already split and transposed to first dimension
# 8 heads, dimension of head is 64, sequence length of 512
q = torch.randn(1, 8, 512, 64)
k = torch.randn(1, 8, 512, 64)
v = torch.randn(1, 8, 512, 64)

attn_fn = FastAttention(
    dim_heads = 64,
    nb_features = 256,
    causal = False
)

out = attn_fn(q, k, v) # (1, 8, 512, 64)
# now merge heads and combine outputs with Wo
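
For completeness, a hedged sketch of that last step, continuing from the example above; the output projection here is just an illustrative placeholder, not something exported by the library.

from torch import nn

out = out.transpose(1, 2).reshape(1, 512, 8 * 64)   # merge heads -> (1, 512, 512)

to_out = nn.Linear(8 * 64, 512)                     # illustrative Wo projection
out = to_out(out)                                   # (1, 512, 512)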

Advanced

At the end of training, if you wish to fix the projection matrices to get the model to output deterministically, you can invoke the following

model.fix_projection_matrices_()

Now your model will have fixed projection matrices across all layers.
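
For example, a quick sanity check, reusing model, x and mask from the language model example above: after fixing the projections, repeated forward passes in eval mode should produce identical outputs.

model.eval()
model.fix_projection_matrices_()

with torch.no_grad():
    out1 = model(x, mask = mask)
    out2 = model(x, mask = mask)

assert torch.allclose(out1, out2)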

Citations

@misc{choromanski2020rethinking,
    title   = {Rethinking Attention with Performers},
    author  = {Krzysztof Choromanski and Valerii Likhosherstov and David Dohan and Xingyou Song and Andreea Gane and Tamas Sarlos and Peter Hawkins and Jared Davis and Afroz Mohiuddin and Lukasz Kaiser and David Belanger and Lucy Colwell and Adrian Weller},
    year    = {2020},
    eprint  = {2009.14794},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@inproceedings{kitaev2020reformer,
    title       = {Reformer: The Efficient Transformer},
    author      = {Nikita Kitaev and Lukasz Kaiser and Anselm Levskaya},
    booktitle   = {International Conference on Learning Representations},
    year        = {2020},
    url         = {https://openreview.net/forum?id=rkgNKkHtvB}
}
@inproceedings{katharopoulos_et_al_2020,
    author  = {Katharopoulos, A. and Vyas, A. and Pappas, N. and Fleuret, F.},
    title   = {Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention},
    booktitle = {Proceedings of the International Conference on Machine Learning (ICML)},
    year    = {2020}
}
@misc{bachlechner2020rezero,
    title   = {ReZero is All You Need: Fast Convergence at Large Depth},
    author  = {Thomas Bachlechner and Bodhisattwa Prasad Majumder and Huanru Henry Mao and Garrison W. Cottrell and Julian McAuley},
    year    = {2020},
    url     = {https://arxiv.org/abs/2003.04887}
}
@article{1910.05895,
    author  = {Toan Q. Nguyen and Julian Salazar},
    title   = {Transformers without Tears: Improving the Normalization of Self-Attention},
    year    = {2019},
    eprint  = {arXiv:1910.05895},
    doi     = {10.5281/zenodo.3525484},
}
@misc{shazeer2020glu,
    title   = {GLU Variants Improve Transformer},
    author  = {Noam Shazeer},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.05202}
}
@misc{roy*2020efficient,
    title   = {Efficient Content-Based Sparse Attention with Routing Transformers},
    author  = {Aurko Roy* and Mohammad Taghi Saffar* and David Grangier and Ashish Vaswani},
    year    = {2020},
    url     = {https://arxiv.org/pdf/2003.05997.pdf}
}
@techreport{zhuiyiroformer,
    title   = {RoFormer: Transformer with Rotary Position Embeddings - ZhuiyiAI},
    author  = {Jianlin Su},
    year    = {2021},
    url     = {https://github.com/ZhuiyiTechnology/roformer}
}
Comments
  • [Feature] EncoderDecoder framework, similar to ReformerEncDec

    Hello Phil,

    Nice job on this great architecture. I want to use it as an Encoder Decoder within Deepspeed, so I am thinking of writing a wrapper similar to the one you did for Reformer. Do you have any tips on what to pay attention to (no pun intended), and whether I need to use padding as in Autopadder?

    Thanks

    opened by gulnazaki 22
  • Causal linear attention benchmark

    First, thanks for this awesome repo!!

    Based on the T5 model classes from Hugging Face's transformers, I was trying to use Performer attention instead of the original T5 attention. We fine-tuned t5-large on a summarization task, profiled both time and memory usage, and compared the Performer attention with the original attention. I have only benchmarked with an input size of 1024.

    The result clearly showed that Performer attention uses a lot less memory compared to the original transformer. I know from the paper that the Performer outperforms the original transformer when the input size is bigger than 1024. However, fine-tuning and generation with the Performer actually took longer, so I profiled the forward call of both the original T5 attention and the Performer attention. The forward of the T5 Performer took twice as long, and the main bottleneck was causal_dot_product_kernel from fast-transformers.

    Is this normal performance for the Performer, or for the causal attention calculation? Or will the Performer attention be faster with a bigger input size?

    opened by ice-americano 13
  • Regarding DDP and reversible networks

    Hi, I'm trying to figure out how to combine DDP with setting the network to be reversible.

    My code basically looks like this:

    import pytorch_lightning as pl
    from torch import nn
    from performer_pytorch import Performer
    ...
    model = nn.Sequential(..., Performer(..., reversible = True))
    trainer = pl.Trainer(...
                        distributed_backend='ddp',
                        ...)
    trainer.fit(model,train_loader,val_loader)
    

    Now all combinations work for me (ddp/not reversible, not ddp/reversible, not ddp/not reversible) except for ddp and reversible.

    The error I get is:

    RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons:

    1. Use of a module parameter outside the forward function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes
    2. Reused parameters in multiple reentrant backward passes. For example, if you use multiple checkpoint functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases yet.

    I've seen multiple people have similar issues: https://github.com/huggingface/transformers/issues/7160, https://github.com/pytorch/pytorch/issues/46166, https://github.com/tatp22/linformer-pytorch/issues/23

    Do you have any suggestions for how to deal with this issue? I'm not really familiar with the inner workings of DDP and the autograd engine, so I'm not sure how to fix this myself.

    opened by Parskatt 11
  • Triangular matrices?

    Does the current implementation provide triangular matrices (to constrain the attention to always look at the "left" of the sequence, both for input and encoded values), as described in the last section of the original paper?

    opened by jeremycochoy 10
  • wrong implementation for autoregressive self-attention

    Hi, I found that you used fast_transformers' CUDA kernel, but it does not contain the normalization part, which needs a cumsum outside the CausalDotProduct (in causal_linear_attention). If I didn't miss something, the result of your code should be wrong... but I am not 100% sure.
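
    For reference, here is a naive (memory-hungry) sketch of what I mean by the normalized causal linear attention, assuming q and k have already been passed through the feature map; it is only illustrative and is not the CUDA kernel:

    import torch

    def naive_causal_linear_attention(q, k, v, eps = 1e-6):
        # q, k: (batch, heads, seq, dim) after the feature map; v: (batch, heads, seq, dim_v)
        k_cumsum = k.cumsum(dim = -2)                                        # prefix sums of phi(k_j)
        context = torch.einsum('bhnd,bhne->bhnde', k, v).cumsum(dim = -3)    # prefix sums of phi(k_j) v_j^T
        num = torch.einsum('bhnd,bhnde->bhne', q, context)
        denom = torch.einsum('bhnd,bhnd->bhn', q, k_cumsum).clamp(min = eps)
        return num / denom.unsqueeze(-1)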

    opened by Sleepychord 10
  • There are no tests in this project, use_rezero=True is non-functional

    Tests are needed to validate that models can train in various configurations. I built and ran simple tests (trying to get authorization to contribute them as a PR) and found that use_rezero=True kills the gradient and results in a Performer model that cannot learn. The fix consists of initializing the rezero parameter with a small value, but not zero (e.g., 1e-3 works in my tests). Zero prevents any signal from passing through the module, so the parameter never changes from zero.
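
    For illustration, the kind of fix I mean is a ReZero-style residual wrapper whose gate is initialized to a small non-zero value (a hypothetical sketch, not the module in this library):

    import torch
    from torch import nn

    class ReZeroResidual(nn.Module):
        def __init__(self, fn, init_value = 1e-3):
            super().__init__()
            self.fn = fn
            self.g = nn.Parameter(torch.tensor(init_value))  # small non-zero init instead of 0

        def forward(self, x, **kwargs):
            return x + self.fn(x, **kwargs) * self.g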

    opened by fcampagne 10
  • Show what the performance on enwiki8 is across your projects

    Hello @lucidrains, I'm a very big fan of your work. It is of such high quality that I lose sleep trying every new project you release.

    You have many different versions of transformers, such as reformer, memory-xl, performer... and apparently you already test them with enwiki8.

    Would it be possible to post in the README a table with the enwiki8 runtime, memory, and some performance metric? That would be awesome for comparing the different implementations.

    Thanks again for your work!!

    opened by bratao 10
  • Issue with biased estimates from QR decomposition

    Hi again :)

    See the issue I posted on the main repository: https://github.com/google-research/google-research/issues/436. Using the QR decomposition as-is produces results with significantly higher variance. There is quite an easy fix, by simply doing

     q, r = torch.qr(flattened)
     # Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf
     d = torch.diag(r, 0)
     ph = d.sign()
     q *= ph
    
    opened by Parskatt 9
  • Collaborate on Implementation?

    I was planning on implementing this in PyTorch as well and started a repo: https://github.com/calclavia/Performer-Pytorch. I've implemented the kernel so far. If the author(s) of this repo want to collaborate, I would be happy to contribute.

    opened by calclavia 9
  • Extra FF when using cross attention

    Hello Phil,

    I have noticed that when using cross attention, a new block (with attention and a FeedForward layer) is added, while only an attention layer should be added between the self-attention and the FF layer.

    Is there any reason for this?

    opened by gulnazaki 8
  • Add feature_redraw_interval option

    This fork allows the user to select a number of forward passes after which the random features will be redrawn. This allows us to avoid doing QR decomposition on the CPU every forward pass. By default it is set to redraw every 1000 passes.

    opened by norabelrose 8
  • Performer Pytorch Slower than Expected and Please Help with Understanding Parameter Count

    Hi,

    First of all, this is a great package from lucidrains and I find it very helpful in my research.

    A quick question: I noticed that the ViT-Performer is slower than the regular ViT from lucidrains. For example, running on MNIST from PyTorch takes 15 sec/epoch for the regular ViT with the configuration below, while the ViT-Performer takes 23 sec/epoch.

    Checking the parameter count also shows that the ViT-Performer has double the parameter count of the regular ViT.

    (Screenshots of the two configurations and parameter counts were attached to the original issue.)

    I am hoping that someone has intuition about the speed of ViT performer vs regular ViT and their parameter counts.

    Thank you very much in advance!

    opened by weihaosong 1
  • Replicating nn.MultiHeadAttention with multiple Performer SelfAttention modules

    As the title says, has anyone tried replacing the multi-head attention in a typical transformer with the SelfAttention as described in this library?

    My thought was that I could essentially concat multiple self-attention elements together to replicate this, per the attached image from the PyTorch website.

    I'm relatively new to transformers as a whole, so hopefully this question makes some sense.

    For reference, considering the interest in a previous post, I've been attempting to explore Performer's effectiveness with DETR (https://github.com/facebookresearch/detr).

    thanks!

    opened by JGittles 0
  • Question about masking

    Hi, thanks for the wonderful repo. I am new to BERT, so I'd like to make sure about your example:

    model = PerformerLM()
    x = torch.randint(0, 20000, (1, 2048))
    mask = torch.ones_like(x).bool()
    model(x, mask = mask) # (1, 2048, 20000)

    Is this 'mask' the attention mask, i.e., True (1) for normal tokens and False (0) for padding tokens? Or is 1 set to indicate padding tokens? Thanks a lot!
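
    For clarity, here is roughly what I mean, following your example where an all-True mask seems to mark real tokens (a hypothetical sketch with 0 as the assumed pad id):

    import torch

    PAD_ID = 0                                # hypothetical padding token id
    x = torch.randint(1, 20000, (2, 2048))
    x[1, 1500:] = PAD_ID                      # second sequence padded after position 1500
    mask = x != PAD_ID                        # True for real tokens, False for padding
    # model(x, mask = mask)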

    opened by Microbiods 1
  • Question: Is Performer order equivariant? (can it transform an unordered set of tensors)

    Hi,

    Thanks for the amazing implementation. I'm wondering if Performer can be used like a set operator (i.e. whether it is order equivariant).

    For instance, say I have a point cloud and I want to apply self-attention across all the point features. Can Performer be used here (note equivariance: points can be arbitrarily shuffled, but we expect the corresponding transformed features to be identical regardless of the shuffling)?

    Thanks!

    opened by nmakes 0
  • Using Performer with GNNs

    My understanding of "Rethinking Attention with Performers" is that FAVOR+ is used to approximate the attention matrix and avoids the use of the softmax function. In the README.md file, you note that the plain Performer can be used for images or other modalities, just as the authors allude to Performer's use in other areas.

    I am interested in using Performer to approximate attention between nodes in a graph neural network. The graph neural network contains vectors characterizing each node's features and boolean edge indices indicating connections between nodes.

    Do you have any recommendations on how this is feasible with the current Performer model? I see that Attention.forward() accepts a mask input.

    opened by jah377 0