Memory Efficient Attention (O(sqrt(n))) for Jax and PyTorch

Overview

Memory Efficient Attention

This is an unofficial implementation of "Self-attention Does Not Need O(n^2) Memory" for Jax and PyTorch.

The implementation is almost the same as the one proposed in the paper, with added support for masking, bias, and batch dimensions, plus a PyTorch implementation. Computing attention with the proposed method requires only O(sqrt(n)) memory, and the provided functions can be used as drop-in replacements for attention calculation.
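To make the idea concrete, here is a minimal NumPy sketch of attention computed over key/value chunks with a running softmax. It is not this library's code: it is a single-head, unbatched, streaming variant written only for illustration, and the names naive_attention and chunked_attention are hypothetical. Only one chunk of scores is held in memory at a time, and the result matches ordinary attention.

```python
import numpy as np

def naive_attention(q, k, v):
    # q: [n_q, d], k: [n_k, d], v: [n_k, d_v]; materializes the full score matrix.
    scores = q @ k.T / np.sqrt(q.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (weights / weights.sum(axis=-1, keepdims=True)) @ v

def chunked_attention(q, k, v, key_chunk_size):
    # Hypothetical illustration: processes keys/values chunk by chunk.
    n_q, d = q.shape
    acc = np.zeros((n_q, v.shape[-1]))        # running weighted sum of values
    denom = np.zeros((n_q, 1))                # running softmax denominator
    running_max = np.full((n_q, 1), -np.inf)  # running max for numerical stability
    for start in range(0, k.shape[0], key_chunk_size):
        k_chunk = k[start:start + key_chunk_size]
        v_chunk = v[start:start + key_chunk_size]
        scores = q @ k_chunk.T / np.sqrt(d)   # only [n_q, chunk] scores exist at once
        new_max = np.maximum(running_max, scores.max(axis=-1, keepdims=True))
        scale = np.exp(running_max - new_max) # rescales earlier partial results
        weights = np.exp(scores - new_max)
        acc = acc * scale + weights @ v_chunk
        denom = denom * scale + weights.sum(axis=-1, keepdims=True)
        running_max = new_max
    return acc / denom
```

Chunking the queries as well (as query_chunk_size does) bounds the size of each score block in both dimensions.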

Important Note: This implementation trades runtime for lower memory use, so you should tune the key_chunk_size and query_chunk_size parameters to find the best configuration for your use case. Here is a note from the paper's authors:

While a constant chunk size for the queries and a chunk size of sqrt(n) for the keys and values is optimal for memory consumption, the runtime is also affected by the choice of chunk size in practice, which is heavily affected by the choice of hardware. Ultimately, we have to leave this trade-off to the programmer, and expose the chunk sizes as arguments query_chunk_size and key_chunk_size. In Figure 1 we provide default values for the chunk sizes that lead to minimal runtime impact (on TPUv2), while still providing significant memory savings.
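As a starting point for tuning, the memory-optimal choice described in the note can be sketched as follows. The helper name suggest_chunk_sizes and its default query chunk size are assumptions made for this example, not part of the library or the paper; the values it returns are only a first guess to benchmark against on your own hardware.

```python
import math

def suggest_chunk_sizes(seq_len, query_chunk_size=1024):
    # Hypothetical helper: constant query chunk, roughly sqrt(n) key/value chunk.
    # The default of 1024 is an arbitrary choice for this sketch.
    key_chunk_size = max(1, math.isqrt(seq_len))
    return min(query_chunk_size, seq_len), min(key_chunk_size, seq_len)

# Example: a sequence of 4096 tokens.
query_chunk, key_chunk = suggest_chunk_sizes(4096)
```

Larger key chunks usually run faster but use more memory, so it is worth timing a few values around this suggestion.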

Quick Start

  1. Install the library
# for Jax
pip install memory-efficient-attention[jax]
# for PyTorch
pip install memory-efficient-attention[torch]
# for Running Tests
pip install memory-efficient-attention[testing]
  2. Compute attention with the proper function
import numpy as np
# for PyTorch
from memory_efficient_attention import efficient_dot_product_attention_pt
# or for Jax
from memory_efficient_attention import efficient_dot_product_attention_jax

# Random Data (batch dimensions are not necessary)
b = 8
query = np.random.rand(1, b, 128, 16, 8).astype("float32")
key = np.random.rand(1, b, 128, 16, 8).astype("float32")
value = np.random.rand(1, b, 128, 16, 8).astype("float32")
# optional, for causal tasks, ...
mask = np.random.rand(1, b, 16, 128, 128) > 0.5
bias = np.random.rand(1, b, 16, 128, 128).astype("float32") / 100

# Adjust chunk sizes (replace each ... with an integer suited to your hardware)
efficient_dot_product_attention_jax(query, key, value, mask, bias, key_chunk_size=..., query_chunk_size=...)
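To check that a chosen configuration still gives the right answer, the output can be compared against a plain, memory-hungry reference. The function below is a sketch written for this README, not part of the library. It assumes the layout used above (query, key, value as [..., length, heads, depth] and mask, bias as [..., heads, q_length, kv_length]), and it assumes that True in the mask means "attend", mirroring the Flax convention.

```python
import numpy as np

def reference_attention(query, key, value, mask=None, bias=None):
    # Hypothetical reference: builds the full [..., heads, q_len, kv_len] score tensor.
    depth = query.shape[-1]
    scores = np.einsum("...qhd,...khd->...hqk", query, key) / np.sqrt(depth)
    if bias is not None:
        scores = scores + bias
    if mask is not None:
        # Assumption: True means the position may be attended to.
        scores = np.where(mask, scores, -1e9)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    # Result has the same layout as query: [..., q_len, heads, depth].
    return np.einsum("...hqk,...khd->...qhd", weights, value)
```

With the arrays from the example above, something like np.allclose(reference_attention(query, key, value, mask, bias), chunked_output, atol=1e-4) would be a reasonable sanity check, where chunked_output is whatever the library call returned, converted to a NumPy array.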

Citation

Please cite if this implementation helps your research. You can use the following BibTeX entry:

@misc{memory_efficient_attention,
	title = {Memory Efficient Attention},
	author = {Rezaei, Amin},
	howpublished = {\url{github.com/AminRezaei0x443/memory-efficient-attention}},
	year = {2021}
}

Also, for the paper:

@misc{rabe2021selfattention,
      title={Self-attention Does Not Need $O(n^2)$ Memory}, 
      author={Markus N. Rabe and Charles Staats},
      year={2021},
      eprint={2112.05682},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Comments
  • feat: output_attentions

    I'm looking into hacking some of the models in the transformers library to use this library for attention, and I don't see a way to support output_attentions yet. This is a flag passed in transformers, where the attention weights are preserved and returned to the user, if it is set.

    I looked a little at implementing this in the torch backend, and I note that the scan() function provides for only a single tensor return value. It seems to me that the scan() function would be most clearly replaced by a for loop, but it could also be modified to handle tuples, or return_weights could be handled by accessing nonlocal data in some way instead of returning the weights through the chunk scanner. I'm also not sure how the output would best be passed to the user.

    Edit: Draft implementation 01/28 at https://github.com/AminRezaei0x443/memory-efficient-attention/compare/main...xloem:faba6371ac7faaa2040a2c26e15ae7ab87f94ce4 . I ended up extending the scan function for parity between implementations. Edit 2: Turns out it's the postsoftmax attention weights, not the presoftmax attention weights. I've updated this post and the draft implementation for this output: https://github.com/AminRezaei0x443/memory-efficient-attention/compare/main...xloem:return_weights

    opened by xloem 4
  • Provide a flag for the user to receive attention weights

    This is my draft code for #1. I saw this feature in the transformers library and wanted to implement it here.

    I'm curious what you think about this feature and implementation.

    The code is simply slightly instrumented so that the final attention weights can be returned to the user. Tests are augmented to test this use. In utils, the scan function is expanded to handle tuples.

    A change to dynamic_slice crept in from dev, to use slices rather than index_slice. I've retained this change because it looks like it would execute faster to me, but it can be removed.

    Rebased and squashed from 84724e1de4721ea0333d6bdbb91e8bce74fbeac .

    opened by xloem 2
  • Improve performance via batched-matmul and fused multiplies

    Many thanks for providing this reference implementation.

    I tried integrating this into stable-diffusion / diffusers. A fix was required to make it work on Mac (PyTorch MPS backend):
    https://github.com/Birch-san/diffusers/pull/1/commits/04372140a25d7f53549175f1f196599c3e9bf3a5

    Knowing that computing attention via baddbmm()+bmm() can outperform einsum by 18%: I tried to rewrite the algorithm to use those.

    I compared the speed of my optimized version, against the implementation in this repository. this result is for "everything fits in one chunk" perf (i.e. chunk size = max token length). I was unable to compare chunked perf, because although I got chunking working in my version: I wasn't able to get it working in the version in this repository (got some unexpected-shape tensors returned).

    compared to the implementation in this repository:
    my optimized version achieves a 2.78x speedup in the time it took to generate a 512x512 image with stable-diffusion v2.1-base (i.e. 4096 vision tokens, 5 attention heads, batch size of 2 due to CFG).

    here's my optimized implementation:
    https://github.com/Birch-san/diffusers/pull/1

    batched matmuls require a 3D tensor, i.e. [batch * num_heads, tokens, channels_per_head].

    code that currently integrates against this repository's [batch, q_length, num_heads, qk_depth_per_head] format can migrate those tensors to the [batch * num_heads, q_length, channels_per_head] format favoured by my implementation like so:

    query = query.transpose(1,2).flatten(end_dim=1)
    key = key.transpose(1,2).flatten(end_dim=1)
    value = value.transpose(1,2).flatten(end_dim=1)
    

    the result that's returned, remains in [batch * num_heads, q_length, qk_depth_per_head] format, and can be restored to [batch, q_length, num_heads, qk_depth_per_head] format like so:

    result.unflatten(0, (-1, attn.heads)).transpose(1,2)
    

    I think a further speedup is possible too: by working out when chunking is not needed: we can compute whether unchunked attention would fit into memory, and prefer unchunked attention as a fast-path where possible. this will be useful in a Unet, which runs attention at various resolutions.

    EDIT:
    I have now added fast-paths for:

    • skipping kv-chunking when kv_chunk_size >= k_tokens
      • this turns the algorithm into "attention slicing"
    • skipping q-chunking when q_chunk_size >= q_tokens
    • skipping all chunking when the kv_chunk_size >= k_tokens and q_chunk_size >= q_tokens
    • skipping all chunking when the attention-score (query-key) matmul requires fewer bytes than a user-provided threshold
    opened by Birch-san 1
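
The fast paths listed in the last comment can be summarized as a small decision function. This is only a sketch of the described logic, not code from either repository: the function name, the returned labels, and the parameters batch_x_heads, bytes_per_element, and chunk_threshold_bytes are all hypothetical.

```python
def choose_attention_path(q_tokens, k_tokens, q_chunk_size, kv_chunk_size,
                          batch_x_heads=1, bytes_per_element=4,
                          chunk_threshold_bytes=None):
    # Size of the full attention-score tensor if it were materialized at once.
    scores_bytes = batch_x_heads * q_tokens * k_tokens * bytes_per_element
    if chunk_threshold_bytes is not None and scores_bytes < chunk_threshold_bytes:
        return "unchunked"            # fits under the user-provided threshold
    skip_q = q_chunk_size >= q_tokens
    skip_kv = kv_chunk_size >= k_tokens
    if skip_q and skip_kv:
        return "unchunked"            # a single chunk covers everything
    if skip_kv:
        return "query_chunked"        # equivalent to "attention slicing"
    if skip_q:
        return "kv_chunked"
    return "fully_chunked"

# Example: 4096 vision tokens, kv chunk covering all keys.
path = choose_attention_path(4096, 4096, q_chunk_size=1024, kv_chunk_size=4096)
```

In a Unet, where attention runs at several resolutions, a check like this could be evaluated per layer so that small layers take the unchunked path.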
Releases (0.1.3)
  • 0.1.2 (Mar 7, 2022)

    What's Changed

    This update fixes torch device handling issues in code. GPU and other kinds of tensors can be used safely.

    • Update utils.py by @yhgon in https://github.com/AminRezaei0x443/memory-efficient-attention/pull/5
    • Update attention_torch.py by @yhgon in https://github.com/AminRezaei0x443/memory-efficient-attention/pull/6

    New Contributors

    • @yhgon made their first contribution in https://github.com/AminRezaei0x443/memory-efficient-attention/pull/5

    Full Changelog: https://github.com/AminRezaei0x443/memory-efficient-attention/compare/0.1.1.0...0.1.2

  • 0.1.1.0 (Feb 3, 2022)

    Added mask and bias calculation functions for custom, memory-efficient per-chunk computation, so masks and biases can now also be computed with sublinear memory.

    Full Changelog: https://github.com/AminRezaei0x443/memory-efficient-attention/compare/0.1.1...0.1.1.0

Owner
Amin Rezaei
Computer Science BSc, Neural Networks Enthusiast