Implementation of the Transformer variant proposed in "Transformer Quality in Linear Time"

Overview

FLASH - Pytorch

Install

$ pip install FLASH-pytorch

Usage

The main novel circuit in this paper is the "Gated Attention Unit" (GAU), which the authors claim can replace multi-headed attention while using just one head.

It uses a squared ReLU activation in place of the softmax. That activation first appeared in the Primer paper, and ReLU-based attention was used in the ReLA Transformer. The gating style seems mostly inspired by gMLPs.

import torch
from flash_pytorch import GAU

gau = GAU(
    dim = 512,
    query_key_dim = 128,     # query / key dimension
    causal = True,           # autoregressive or not
    expansion_factor = 2,    # hidden dimension = dim * expansion_factor
)

x = torch.randn(1, 1024, 512)
out = gau(x) # (1, 1024, 512)
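
To make the activation swap concrete, here is a minimal sketch of single-head attention with squared ReLU in place of softmax. It is not the library's GAU: the gating, the learned projections, and the positional bias are all left out, the function name `relu_squared_attention` is hypothetical, and dividing the scores by the sequence length is an assumption made for illustration.

```python
import torch
import torch.nn.functional as F

def relu_squared_attention(q, k, v, causal = False):
    # hypothetical helper, not part of flash_pytorch
    # q, k: (batch, seq, qk_dim)   v: (batch, seq, v_dim)
    n = q.shape[1]
    sim = torch.einsum('b i d, b j d -> b i j', q, k) / n   # assumed scaling by sequence length
    attn = F.relu(sim) ** 2                                  # squared ReLU instead of softmax
    if causal:
        # zero out attention to future positions
        future = torch.ones(n, n, dtype = torch.bool, device = q.device).triu(1)
        attn = attn.masked_fill(future, 0.)
    return torch.einsum('b i j, b j d -> b i d', attn, v)

q = torch.randn(1, 8, 16)
k = torch.randn(1, 8, 16)
v = torch.randn(1, 8, 32)
out = relu_squared_attention(q, k, v, causal = True)  # (1, 8, 32)
```

Unlike softmax, the squared ReLU weights are not normalized to sum to one, so a masked position is simply set to zero rather than to negative infinity.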

The authors then combine GAU with the linear attention of Katharopoulos et al., splitting the sequence into groups to work around a known issue with autoregressive linear attention.

They named this combination of the quadratic gated attention unit with grouped linear attention FLASH.
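
As a rough illustration of the grouping idea, the sketch below computes only the linear (cross-group) part of causal attention: each group gets one key-value summary, and a group attends to the running sum of the summaries of strictly earlier groups. This is an assumption-laden simplification, not the library's code: the function name is hypothetical, normalization is omitted, and the within-group quadratic attention that FLASH adds on top is not shown.

```python
import torch

def grouped_causal_linear_attention(q, k, v):
    # hypothetical helper, not part of flash_pytorch
    # q, k: (batch, groups, group_size, qk_dim)   v: (batch, groups, group_size, v_dim)
    kv = torch.einsum('b g n d, b g n e -> b g d e', k, v)   # one summary per group
    kv = kv.cumsum(dim = 1)                                   # inclusive running sum over groups
    # shift by one group so that group g only sees groups < g
    kv = torch.cat((torch.zeros_like(kv[:, :1]), kv[:, :-1]), dim = 1)
    return torch.einsum('b g n d, b g d e -> b g n e', q, kv)

q = torch.randn(2, 3, 4, 8)
k = torch.randn(2, 3, 4, 8)
v = torch.randn(2, 3, 4, 16)
out = grouped_causal_linear_attention(q, k, v)  # (2, 3, 4, 16)
```

The cumulative sum runs over groups rather than over individual tokens, so its length is the number of groups, not the sequence length.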

The FLASH layer can also be used on its own:

import torch
from flash_pytorch import FLASH

flash = FLASH(
    dim = 512,
    group_size = 256,             # group size
    causal = True,                # autoregressive or not
    query_key_dim = 128,          # query / key dimension
    expansion_factor = 2.         # hidden dimension = dim * expansion_factor
)

x = torch.randn(1, 1111, 512)     # sequence will be auto-padded to nearest group size
out = flash(x) # (1, 1111, 512)
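
The padding arithmetic for the example above can be worked out by hand. The sketch below assumes that "auto-padded" means rounding the sequence length up to the next multiple of `group_size`; the helper name `padded_length` is hypothetical and the library's internal padding code is not shown here.

```python
import math

def padded_length(seq_len, group_size):
    # hypothetical helper: round seq_len up to the next multiple of group_size
    return math.ceil(seq_len / group_size) * group_size

seq_len, group_size = 1111, 256
total = padded_length(seq_len, group_size)   # 1280
num_groups = total // group_size             # 5
padding = total - seq_len                    # 169
```

Under that assumption the layer would process five groups of 256 positions internally, while the returned tensor keeps the original length of 1111, as the shape comment in the usage example indicates.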

Finally, you can use the full FLASH transformer described in the paper, which includes all of the positional embeddings it mentions. The absolute positional embedding is a scaled sinusoidal embedding. The quadratic GAU attention gets a one-headed T5 relative positional bias. On top of this, both the GAU attention and the linear attention use rotary embeddings (RoPE).

import torch
from flash_pytorch import FLASHTransformer

model = FLASHTransformer(
    num_tokens = 20000,          # number of tokens
    dim = 512,                   # model dimension
    depth = 12,                  # depth
    causal = True,               # autoregressive or not
    group_size = 256,            # size of the groups
    query_key_dim = 128,         # dimension of queries / keys
    expansion_factor = 2.,       # hidden dimension = dim * expansion_factor
    norm_type = 'scalenorm',     # the paper claims scalenorm trains faster with no performance hit; the other option is 'layernorm' (also the default)
    shift_tokens = True          # discovered by @BlinkDL, an independent researcher in Shenzhen: shifts half of the feature space forward one step along the sequence dimension, which greatly improved convergence in my local experiments
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x) # (1, 1024, 20000)
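
The token shift behind `shift_tokens` can be illustrated on a tiny tensor. This is a sketch of the pad-and-concatenate pattern, written as a standalone function with a hypothetical name; in the library the shift is applied to the normalized input inside the layer.

```python
import torch
import torch.nn.functional as F

def shift_half_features(x):
    # hypothetical standalone helper; x: (batch, seq, dim)
    x_shift, x_pass = x.chunk(2, dim = -1)
    # pad one zero row at the start of the sequence and crop the last row,
    # so position t receives the first-half features of position t - 1
    x_shift = F.pad(x_shift, (0, 0, 1, -1), value = 0.)
    return torch.cat((x_shift, x_pass), dim = -1)

x = torch.arange(1., 9.).reshape(1, 4, 2)
# x[0] = [[1, 2], [3, 4], [5, 6], [7, 8]]
y = shift_half_features(x)
# y[0] = [[0, 2], [1, 4], [3, 6], [5, 8]]
```

Each position ends up with half of its own features and half of the previous position's features. The misalignment between the two halves is the point of the operation, and because information only moves forward along the sequence it stays compatible with causal masking.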

Test on Autoregressive Enwik8

$ python train.py

Citations

@article{Hua2022TransformerQI,
    title   = {Transformer Quality in Linear Time},
    author  = {Weizhe Hua and Zihang Dai and Hanxiao Liu and Quoc V. Le},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2202.10447}
}
@software{peng_bo_2021_5196578,
    author    = {PENG Bo},
    title     = {BlinkDL/RWKV-LM: 0.01},
    month     = {aug},
    year      = {2021},
    publisher = {Zenodo},
    version   = {0.01},
    doi       = {10.5281/zenodo.5196578},
    url       = {https://doi.org/10.5281/zenodo.5196578}
}
Comments
  • einsum operation in Linear Attention Part

    Hi, thanks a lot for FLASH-pytorch, it helps a lot. I found some differences from the paper in the linear attention part: https://github.com/lucidrains/FLASH-pytorch/blob/main/flash_pytorch/flash_pytorch.py#L342-L343

    lin_kv = einsum('b g n d, b g n e -> b d e', lin_k, v) / n
    lin_out = einsum('b g n d, b d e -> b g n e', lin_q, lin_kv)
    

    Here lin_kv is three-dimensional (b d e), whereas the code in the paper is

    lin_kv = tf.einsum('bhke,bgh->bgke', lin_kv, mask)
    linear = tf.einsum('bgnk,bgke->bgne', lin_q, lin_kv)
    

    where lin_kv is four-dimensional (b g k e). The two approaches do not seem to be equivalent.

    Looking forward to your reply. Best,

    opened by ShomyLiu 5
  • mask error

    x = torch.randint(0, 20000, (1, 1024))
    mask = x.ne(0)
    logits = model(x, mask=mask)
    

    RuntimeError: The size of tensor a (1024) must match the size of tensor b (128) at non-singleton dimension 2

    opened by keyunluo 1
  • Speed on TPU

    Hi, thanks for the code! I tested it on a Google TPU v3, and the training speed seems slower than I expected. Maybe some operation is not lowered to the TPU.

    opened by magicknight 0
  • About the "shift_tokens"

    Thank you for your amazing code.

    In the FLASH class, I found a flag, shift_tokens, with the following corresponding code:

    if self.shift_tokens:
        x_shift, x_pass = normed_x.chunk(2, dim = -1)
        x_shift = F.pad(x_shift, (0, 0, 1, -1), value = 0.)
        normed_x = torch.cat((x_shift, x_pass), dim = -1)

    Assume normed_x has shape [1024, 512], so x_shift and x_pass each have shape [1024, 256]. The code then adds a row of zeros at the start of x_shift, removes its last row, and concatenates x_shift and x_pass to get the new normed_x.

    In my opinion, after the F.pad operation the rows of x_shift and x_pass no longer line up.

    May I know why it works?

    Kang

    opened by kangzhao2 1
  • Cross-Attention?

    Hi, @lucidrains. Thank you for sharing this excellent implementation with us all! Do you have any thoughts as to what changes would need to be made to make cross-attention possible with your FLASH model?

    opened by amorehead 2
Owner
Phil Wang