Implementation of a memory efficient multi-head attention as proposed in the paper, "Self-attention Does Not Need O(n²) Memory"

Overview

Memory Efficient Attention Pytorch

Implementation of memory-efficient multi-head attention as proposed in the paper "Self-attention Does Not Need O(n²) Memory". The module also handles masking, causal masking, and cross attention.
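The idea from the paper can be illustrated without any tensor library. The following is a minimal plain-Python sketch for a single query, not the code used in this repository: keys and values are visited chunk by chunk while a running maximum, a running sum of exponentials, and a running weighted sum of values are kept, so the full row of attention weights is never stored. The function names `attention_one_query` and `attention_reference` are made up for this illustration.

```python
import math
import random

def attention_one_query(q, keys, values, chunk_size):
    """Attention output for one query, visiting keys/values chunk by chunk."""
    scale = 1.0 / math.sqrt(len(q))
    running_max = float("-inf")            # largest score seen so far
    running_sum = 0.0                      # sum of exp(score - running_max)
    running_out = [0.0] * len(values[0])   # weighted value sum under the same scaling

    for start in range(0, len(keys), chunk_size):
        k_chunk = keys[start:start + chunk_size]
        v_chunk = values[start:start + chunk_size]
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_chunk]

        new_max = max(running_max, max(scores))
        # Rescale everything accumulated under the old maximum.
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        running_out = [o * correction for o in running_out]

        for s, v in zip(scores, v_chunk):
            w = math.exp(s - new_max)
            running_sum += w
            running_out = [o + w * x for o, x in zip(running_out, v)]
        running_max = new_max

    return [o / running_sum for o in running_out]

def attention_reference(q, keys, values):
    """Ordinary softmax attention that builds the full weight row at once."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    dim_v = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim_v)]

random.seed(0)
q = [random.gauss(0, 1) for _ in range(4)]
keys = [[random.gauss(0, 1) for _ in range(4)] for _ in range(10)]
values = [[random.gauss(0, 1) for _ in range(3)] for _ in range(10)]

chunked = attention_one_query(q, keys, values, chunk_size=4)
reference = attention_reference(q, keys, values)
print(max(abs(a - b) for a, b in zip(chunked, reference)))  # tiny floating point difference
```

The chunked result matches the ordinary softmax up to floating point rounding, whatever chunk size is chosen.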

Install

$ pip install memory-efficient-attention-pytorch

Usage

For an autoregressive language model

import torch
from memory_efficient_attention_pytorch import Attention

attn = Attention(
    dim = 512,
    dim_head = 64,                # dimension per head
    heads = 8,                    # number of attention heads
    causal = True,                # autoregressive or not
    memory_efficient = True,      # whether to use memory efficient attention (can be turned off to test against normal attention)
    q_bucket_size = 1024,         # bucket size along queries dimension
    k_bucket_size = 2048          # bucket size along key / values dimension
).cuda()

x = torch.randn(1, 65536, 512).cuda()
out = attn(x) # (1, 65536, 512)
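
To get a feel for what the bucket sizes buy, here is some back-of-the-envelope arithmetic for the settings above. It assumes that only one query-bucket by key-bucket tile of attention scores is live at a time per head, and it ignores every other buffer (projections, checkpointing, gradients), so treat the numbers as a rough guide rather than a measurement.

```python
seq_len = 65536        # sequence length from the example above
q_bucket_size = 1024   # bucket size along the queries dimension
k_bucket_size = 2048   # bucket size along the key / values dimension
bytes_per_float = 4    # assuming float32 scores

full_scores = seq_len * seq_len               # entries in a dense n x n score matrix, per head
tile_scores = q_bucket_size * k_bucket_size   # entries in one bucket x bucket tile, per head
num_tiles = (seq_len // q_bucket_size) * (seq_len // k_bucket_size)

print(full_scores * bytes_per_float / 2**30)  # 16.0 (GiB per head for the dense matrix)
print(tile_scores * bytes_per_float / 2**20)  # 8.0 (MiB per head for one tile)
print(num_tiles)                              # 2048 tiles to visit
print(full_scores // tile_scores)             # 2048 times fewer score entries held at once
```

Smaller buckets lower the size of each tile but increase the number of tiles to visit, so the bucket sizes trade memory against the number of sequential steps.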

Cross attention

import torch
from memory_efficient_attention_pytorch import Attention

cross_attn = Attention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    memory_efficient = True,
    q_bucket_size = 1024,
    k_bucket_size = 2048
).cuda()

x = torch.randn(1, 65536, 512).cuda()
context = torch.randn(1, 65536, 512).cuda()
mask = torch.ones(1, 65536).bool().cuda()

out = cross_attn(x, context = context, mask = mask) # (1, 65536, 512)
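
How a key mask can be combined with bucketing along the key dimension is sketched below. This is a self-contained illustration in plain PyTorch, not the code used by this module: the function names are hypothetical, only the key / values dimension is bucketed, and every batch element is assumed to have at least one unmasked key.

```python
import torch

def dense_masked_attention(q, k, v, key_mask):
    """Reference: builds the full (b, n, m) score tensor. key_mask is True where a key may be attended."""
    scale = q.shape[-1] ** -0.5
    scores = torch.einsum("bnd,bmd->bnm", q, k) * scale
    scores = scores.masked_fill(~key_mask[:, None, :], -torch.finfo(q.dtype).max)
    return torch.einsum("bnm,bmd->bnd", scores.softmax(dim=-1), v)

def chunked_masked_attention(q, k, v, key_mask, k_bucket_size):
    """Same result, but only one (b, n, k_bucket_size) block of scores exists at a time."""
    scale = q.shape[-1] ** -0.5
    b, n, _ = q.shape
    neg = -torch.finfo(q.dtype).max

    run_max = torch.full((b, n, 1), neg, dtype=q.dtype, device=q.device)
    run_sum = torch.zeros(b, n, 1, dtype=q.dtype, device=q.device)
    run_out = torch.zeros(b, n, v.shape[-1], dtype=q.dtype, device=q.device)

    chunks = zip(k.split(k_bucket_size, dim=1),
                 v.split(k_bucket_size, dim=1),
                 key_mask.split(k_bucket_size, dim=1))
    for k_c, v_c, m_c in chunks:
        hidden = ~m_c[:, None, :]                       # (b, 1, chunk), True where masked out
        scores = torch.einsum("bnd,bmd->bnm", q, k_c) * scale
        scores = scores.masked_fill(hidden, neg)

        new_max = torch.maximum(run_max, scores.amax(dim=-1, keepdim=True))
        correction = (run_max - new_max).exp()          # rescale earlier accumulators
        weights = (scores - new_max).exp().masked_fill(hidden, 0.)

        run_sum = run_sum * correction + weights.sum(dim=-1, keepdim=True)
        run_out = run_out * correction + torch.einsum("bnm,bmd->bnd", weights, v_c)
        run_max = new_max

    return run_out / run_sum

torch.manual_seed(0)
q = torch.randn(2, 5, 4)
k = torch.randn(2, 7, 4)
v = torch.randn(2, 7, 4)
key_mask = torch.rand(2, 7) > 0.3
key_mask[:, 0] = True  # keep at least one visible key per batch element

out_chunked = chunked_masked_attention(q, k, v, key_mask, k_bucket_size=3)
out_dense = dense_masked_attention(q, k, v, key_mask)
print(torch.allclose(out_chunked, out_dense, atol=1e-5))
```

Masked keys contribute a weight of exactly zero in every chunk, so changing the values at masked positions leaves the output unchanged.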

Todo

  • benchmark and see how much torch jit helps
  • look at Triton and Keops and see if either can be a fit

Citations

@misc{rabe2021selfattention,
    title   = {Self-attention Does Not Need $O(n^2)$ Memory}, 
    author  = {Markus N. Rabe and Charles Staats},
    year    = {2021},
    eprint  = {2112.05682},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}

@misc{liu2021swin,
    title   = {Swin Transformer V2: Scaling Up Capacity and Resolution},
    author  = {Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
    year    = {2021},
    eprint  = {2111.09883},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}

Comments

  • [feature request] Combining with flash attention?

    There is a new algorithm for optimizing the QKV attention computation: https://github.com/HazyResearch/flash-attention (paper: https://arxiv.org/abs/2205.14135). Maybe you could look into integrating it with this repository.

    opened by Vbansal21 15
  • i did this, we could build on top

    Hi there!

    It seems I have already written some of this code: https://github.com/CHARM-Tx/linear_mem_attention_pytorch. Could we build on top of it? I talked to https://github.com/Chillee about an experimental feature in functorch (https://github.com/pytorch/functorch) that would allow for increased speed. Mainly I want to match JAX performance, but that is difficult with PyTorch's imperative style.

    I would love to collaborate on this if you want!

    opened by hypnopump 5
  • Added dropout support to memory efficient variant

    Hey Phil,

    I have been using this repository for a project and wanted to add dropout for completeness. I checked consistency with the perceiver-ar implementation. I hope this is helpful.

    -Matt

    opened by usryokousha 2
  • Making this work with relative position bias from XTransformers

    Is there a way to make this work with RelativePositionBias? Currently it produces an attention bias of size $BHN^2$, where B is the batch size, H is the number of heads, and N is the input length. Can the bias be chunked and computed per chunk?

    opened by pfeatherstone 5
  • save_for_backward can only save variables, but argument 5 is of type bool

    Hi,

    Thank you for your indescribable work. I was trying to test your method specifically for cross-attention, but I get the error "save_for_backward can only save variables, but argument 5 is of type bool". I am not sure what I am doing wrong. I tried your own examples too and got the same error.

    Can you please help me out?

    Code:

    import torch
    from memory_efficient_attention_pytorch import Attention

    cross_attn = Attention(
        dim = 512,
        dim_head = 64,
        heads = 8,
        memory_efficient = True,
        q_bucket_size = 1024,
        k_bucket_size = 2048
    ).cuda()

    # out = sm_mod(inp1)
    x = torch.randn(1, 65536, 512).cuda()
    context = torch.randn(1, 65536, 512).cuda()
    # mask = torch.ones(1, 65536).bool().cuda()
    out = cross_attn(x

    ERROR:

    File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/runpy.py", line 194, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "/home/abali/.vscode-server/extensions/ms-python.python-2022.8.1/pythonFiles/lib/python/debugpy/main.py", line 45, in <module>
      cli.main()
    File "/home/abali/.vscode-server/extensions/ms-python.python-2022.8.1/pythonFiles/lib/python/debugpy/../debugpy/server/cli.py", line 444, in main
      run()
    File "/home/abali/.vscode-server/extensions/ms-python.python-2022.8.1/pythonFiles/lib/python/debugpy/../debugpy/server/cli.py", line 285, in run_file
      runpy.run_path(target_as_str, run_name=compat.force_str("main"))
    File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/runpy.py", line 265, in run_path
      return _run_module_code(code, init_globals, run_name,
    File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/runpy.py", line 97, in _run_module_code
      _run_code(code, mod_globals, init_globals,
    File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "/data/stars/user/abali/Phd_work/ISBI2023/X3D-Multigrid/CrossAttn_X3d_v2.py", line 872, in <module>
      out = cross_attn(x, context = context, mask = mask) # (1, 65536, 512) print(out)
    File "/home/abali/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
      result = self.forward(*input, **kwargs)
    File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/site-packages/memory_efficient_attention_pytorch/memory_efficient_attention.py", line 215, in forward
      out = attn_fn(q, k, v, mask = mask, attn_bias = attn_bias, causal = self.causal, q_bucket_size = q_bucket_size, k_bucket_size = k_bucket_size)
    File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/site-packages/memory_efficient_attention_pytorch/memory_efficient_attention.py", line 127, in memory_efficient_attention
      exp_weight_chunk, weighted_value_chunk, weight_max_chunk = summarize_qkv_fn(
    File "/home/abali/.local/lib/python3.8/site-packages/torch/utils/checkpoint.py", line 163, in checkpoint
      return CheckpointFunction.apply(function, preserve, *args)
    TypeError: save_for_backward can only save variables, but argument 5 is of type bool

    opened by aliabid2243 1
  • Checkpointing is not compatible with .grad() or when an `inputs` parameter is passed to .backward()

    https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/35559a05572f9d4eb982a8e2e399b40a2d61b85c/memory_efficient_attention_pytorch/memory_efficient_attention.py#L95

    Should this be:

        summarize_qkv_fn = summarize_qkv_chunk if needs_backwards else checkpointed_summarize_qkv_chunk

    instead of:

        summarize_qkv_fn = checkpointed_summarize_qkv_chunk if needs_backwards else summarize_qkv_chunk

    opened by vrobot 0
Releases(0.1.1)
Owner
Phil Wang