Implementation of the Transformer variant proposed in "Transformer Quality in Linear Time"

Overview

FLASH - Pytorch

Implementation of the Transformer variant proposed in the paper Transformer Quality in Linear Time

Install

$ pip install FLASH-pytorch

Usage

The main novel circuit in this paper is the "Gated Attention Unit", which they claim can replace multi-headed attention while reducing it to just one head.

It uses a relu squared activation in place of the softmax, the activation of which was first seen in the Primer paper, and the use of ReLU in ReLA Transformer. The gating style seems mostly inspired by gMLPs.

import torch
from flash_pytorch import GAU

gau = GAU(
    dim = 512,
    query_key_dim = 128,     # query / key dimension
    causal = True,           # autoregressive or not
    expansion_factor = 2,    # hidden dimension = dim * expansion_factor
)

x = torch.randn(1, 1024, 512)
out = gau(x) # (1, 1024, 512)

The authors then combine GAU with Katharopoulos linear attention, using grouping of the sequences to overcome a known issue with autoregressive linear attention.

This combination of the quadratic gated attention unit with grouped linear attention they named FLASH

You can also use this quite easily

import torch
from flash_pytorch import FLASH

flash = FLASH(
    dim = 512,
    group_size = 256,             # group size
    causal = True,                # autoregressive or not
    query_key_dim = 128,          # query / key dimension
    expansion_factor = 2.         # hidden dimension = dim * expansion_factor
)

x = torch.randn(1, 1111, 512)     # sequence will be auto-padded to nearest group size
out = flash(x) # (1, 1111, 512)

Finally, you can use the full FLASH transformer as mentioned in the paper. This contains all the positional embeddings mentioned in the paper. Absolute positional embedding uses scaled sinusoidal. GAU quadratic attention will get one-headed T5 relative positional bias. On top of all this, both GAU attention as well as the linear attention will be rotary embedded (RoPE).

import torch
from flash_pytorch import FLASHTransformer

model = FLASHTransformer(
    num_tokens = 20000,          # number of tokens
    dim = 512,                   # model dimension
    depth = 12,                  # depth
    causal = True,               # autoregressive or not
    group_size = 256,            # size of the groups
    query_key_dim = 128,         # dimension of queries / keys
    expansion_factor = 2.,       # hidden dimension = dim * expansion_factor
    norm_type = 'scalenorm',     # in the paper, they claimed scalenorm led to faster training at no performance hit. the other option is 'layernorm' (also default)
    shift_tokens = True          # discovered by an independent researcher in Shenzhen @BlinkDL, this simply shifts half of the feature space forward one step along the sequence dimension - greatly improved convergence even more in my local experiments
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x) # (1, 1024, 20000)

Test on Autoregressive Enwik8

$ python train.py

Citations

@article{Hua2022TransformerQI,
    title   = {Transformer Quality in Linear Time},
    author  = {Weizhe Hua and Zihang Dai and Hanxiao Liu and Quoc V. Le},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2202.10447}
}
@software{peng_bo_2021_5196578,
    author    = {PENG Bo},
    title     = {BlinkDL/RWKV-LM: 0.01},
    month     = {aug},
    year      = {2021},
    publisher = {Zenodo},
    version   = {0.01},
    doi       = {10.5281/zenodo.5196578},
    url       = {https://doi.org/10.5281/zenodo.5196578}
}
Comments
  • einsum operation in Linear Attention Part

    einsum operation in Linear Attention Part

    Hi, Thanks a lot for your FLASH_pytorch, which helps a lot. I found that there are some differences from the paper in the Linear Attention Part: https://github.com/lucidrains/FLASH-pytorch/blob/main/flash_pytorch/flash_pytorch.py#L342-L343

    lin_kv = einsum('b g n d, b g n e -> b d e', lin_k, v) / n
    lin_out = einsum('b g n d, b d e -> b g n e', lin_q, lin_kv)
    

    the lin_kv is three-dim (bde) And the code in the paper is

    lin_kv = tf.einsum('bhke,bgh→bgke', lin_kv, mask) 
    linear = tf.einsum('bgnk,bgke→bgne', lin_q, lin_kv)
    

    the lin_kv is four-dim (bgke) It seems that the two ways are not equivalent.

    Looking forward to your reply. Best,

    opened by ShomyLiu 5
  • mask error

    mask error

    x = torch.randint(0, 20000, (1, 1024))
    mask = x.ne(0)
    logits = model(x, mask=mask)
    

    RuntimeError: The size of tensor a (1024) must match the size of tensor b (128) at non-singleton dimension 2

    opened by keyunluo 1
  • Speed on TPU

    Speed on TPU

    Hi, Thanks for the code! I test it on Google TPU v3, the training speed seems slower than my expectation. Maybe there is some operation which is not lower on TPU.

    opened by magicknight 0
  • About the

    About the "shift_tokens"

    Thank you for your amazing code.

    In the class of FLASH, I find a flag: shift_tokens, and the corresponding code is as following: if self.shift_tokens: x_shift, x_pass = normed_x.chunk(2, dim = -1) x_shift = F.pad(x_shift, (0, 0, 1, -1), value = 0.) normed_x = torch.cat((x_shift, x_pass), dim = -1)

    Assume we have normed_x in the shape [1024, 512], the x_shift/x_pass is the shape of [1024, 256]. Then it adds a row (with all 0 value) and remove the last row in the x_shift, and concat x_shift and x_pass to get the normed_x.

    In my opinion, the F.pad operation will make the row in x_shift and x_pass do not match again.

    May I know why it works?

    Kang

    opened by kangzhao2 1
  • Cross-Attention?

    Cross-Attention?

    Hi, @lucidrains. Thank you for sharing this excellent implementation with us all! Do you have any thoughts as to what changes would need to be made to make cross-attention possible with your FLASH model?

    opened by amorehead 2
Owner
Phil Wang
Working with Attention. It's all we need
Phil Wang
Learning Open-World Object Proposals without Learning to Classify

Learning Open-World Object Proposals without Learning to Classify Pytorch implementation for "Learning Open-World Object Proposals without Learning to

Dahun Kim 149 Dec 22, 2022
Source code for "Pack Together: Entity and Relation Extraction with Levitated Marker"

PL-Marker Source code for Pack Together: Entity and Relation Extraction with Levitated Marker. Quick links Overview Setup Install Dependencies Data Pr

THUNLP 173 Dec 30, 2022
Weakly Supervised Learning of Instance Segmentation with Inter-pixel Relations, CVPR 2019 (Oral)

Weakly Supervised Learning of Instance Segmentation with Inter-pixel Relations The code of: Weakly Supervised Learning of Instance Segmentation with I

Jiwoon Ahn 472 Dec 29, 2022
Official implementation of "StyleCariGAN: Caricature Generation via StyleGAN Feature Map Modulation" (SIGGRAPH 2021)

StyleCariGAN in PyTorch Official implementation of StyleCariGAN:Caricature Generation via StyleGAN Feature Map Modulation in PyTorch Requirements PyTo

PeterZhouSZ 49 Oct 31, 2022
Dark Finix: All in one hacking framework with almost 100 tools

Dark Finix - Hacking Framework. Dark Finix is a all in one hacking framework wit

Md. Nur habib 2 Feb 18, 2022
Configure SRX interfaces with Scrapli

Configure SRX interfaces with Scrapli Overview This example will show how to configure interfaces on Juniper's SRX firewalls. In addition to the Pytho

Calvin Remsburg 1 Jan 07, 2022
Steer OpenAI's Jukebox with Music Taggers

TagBox Steer OpenAI's Jukebox with Music Taggers! The closest thing we have to VQGAN+CLIP for music! Unsupervised Source Separation By Steering Pretra

Ethan Manilow 34 Nov 02, 2022
An Ensemble of CNN (Python 3.5.1 Tensorflow 1.3 numpy 1.13)

An Ensemble of CNN (Python 3.5.1 Tensorflow 1.3 numpy 1.13)

0 May 06, 2022
This GitHub repo consists of Code and Some results of project- Diabetes Treatment using Gold nanoparticles. These Consist of ML Models used for prediction Diabetes and further the basic theory and working of Gold nanoparticles.

GoldNanoparticles This GitHub repo consists of Code and Some results of project- Diabetes Treatment using Gold nanoparticles. These Consist of ML Mode

1 Jan 30, 2022
A 2D Visual Localization Framework based on Essential Matrices [ICRA2020]

A 2D Visual Localization Framework based on Essential Matrices This repository provides implementation of our paper accepted at ICRA: To Learn or Not

Qunjie Zhou 27 Nov 07, 2022
SPEAR: Semi suPErvised dAta progRamming

Semi-Supervised Data Programming for Data Efficient Machine Learning SPEAR is a library for data programming with semi-supervision. The package implem

decile-team 91 Dec 06, 2022
Efficient Conformer: Progressive Downsampling and Grouped Attention for Automatic Speech Recognition

Efficient Conformer: Progressive Downsampling and Grouped Attention for Automatic Speech Recognition Official implementation of the Efficient Conforme

Maxime Burchi 145 Dec 30, 2022
Churn prediction

Churn-prediction Churn-prediction Data preprocessing:: Label encoder is used to normalize the categorical variable Data Transformation:: For each data

1 Sep 28, 2022
Anderson Acceleration for Deep Learning

Anderson Accelerated Deep Learning (AADL) AADL is a Python package that implements the Anderson acceleration to speed-up the training of deep learning

Oak Ridge National Laboratory 7 Nov 24, 2022
This repository contains a PyTorch implementation of "AD-NeRF: Audio Driven Neural Radiance Fields for Talking Head Synthesis".

AD-NeRF: Audio Driven Neural Radiance Fields for Talking Head Synthesis | Project Page | Paper | PyTorch implementation for the paper "AD-NeRF: Audio

551 Dec 29, 2022
Circuit Training: An open-source framework for generating chip floor plans with distributed deep reinforcement learning

Circuit Training: An open-source framework for generating chip floor plans with distributed deep reinforcement learning. Circuit Training is an open-s

Google Research 479 Dec 25, 2022
Code for Piggyback: Adapting a Single Network to Multiple Tasks by Learning to Mask Weights

Piggyback: https://arxiv.org/abs/1801.06519 Pretrained masks and backbones are available here: https://uofi.box.com/s/c5kixsvtrghu9yj51yb1oe853ltdfz4q

Arun Mallya 165 Nov 22, 2022
MERLOT: Multimodal Neural Script Knowledge Models

merlot MERLOT: Multimodal Neural Script Knowledge Models MERLOT is a model for learning what we are calling "neural script knowledge" -- representatio

Rowan Zellers 190 Dec 22, 2022
Differentiable rasterization applied to 3D model simplification tasks

nvdiffmodeling Differentiable rasterization applied to 3D model simplification tasks, as described in the paper: Appearance-Driven Automatic 3D Model

NVIDIA Research Projects 336 Dec 30, 2022
Explore the Expression: Facial Expression Generation using Auxiliary Classifier Generative Adversarial Network

Explore the Expression: Facial Expression Generation using Auxiliary Classifier Generative Adversarial Network This is the official implementation of

azad 2 Jul 09, 2022