Memory Efficient Attention (O(sqrt(n)) for Jax and PyTorch

Overview

Memory Efficient Attention

arXiv PyPI version

This is unofficial implementation of Self-attention Does Not Need O(n^2) Memory for Jax and PyTorch.

Implementation is almost same as the one proposed in the paper, with additional masking and adding bias compatibility, batch dimensions support and PyTorch implementation. For computing attention, the proposed method requires only O(sqrt(n)) memory, and the provided functions can be used as a drop-in replacement for attention calculation.

Important Note: This implementation is a trade-off between memory requirements and runtime, so you should adjust key_chunk_size and query_chunk_size parameters to achieve the best configuration for your usecase. Here is a note from the paper's authors:

While a constant chunk size for the queries and a chunk size of sqrt(n) for the keys and values is optimal for memory consumption, the runtime is also affected by the choice of chunk size in practice, which is heavily affected by the choice of hardware. Ultimately, we have to leave this trade-off to the programmer, and expose the chunk sizes as arguments query_chunk_size and key_chunk_size. In Figure 1 we provide default values for the chunk sizes that lead to minimal runtime impact (on TPUv2), while still providing significant memory savings.

Quick Start

  1. Install the library
# for Jax
pip install memory-efficient-attention[jax]
# for PyTorch
pip install memory-efficient-attention[torch]
# for Running Tests
pip install memory-efficient-attention[testing]
  1. Compute attention with the proper function
0.5 bias = np.random.rand(1, b, 16, 128, 128).astype("float32") / 100 # Adjust chunk sizes efficient_dot_product_attention_jax(query, key, value, mask, bias, key_chunk_size=..., query_chunk_size=...)">
import numpy as np
# for PyTorch
from memory_efficient_attention import efficient_dot_product_attention_pt
# or for Jax
from memory_efficient_attention import efficient_dot_product_attention_jax

# Random Data (batch dimensions are not necessary)
b = 8
query = np.random.rand(1, b, 128, 16, 8).astype("float32")
key = np.random.rand(1, b, 128, 16, 8).astype("float32")
value = np.random.rand(1, b, 128, 16, 8).astype("float32")
# optional, for casual tasks, ...
mask = np.random.rand(1, b, 16, 128, 128) > 0.5
bias = np.random.rand(1, b, 16, 128, 128).astype("float32") / 100

# Adjust chunk sizes        
efficient_dot_product_attention_jax(query, key, value, mask, bias, key_chunk_size=..., query_chunk_size=...)

Citation

Please cite if this implementation helps your research. You can use the following BibTeX entry:

@misc{memory_efficient_attention,
	title = {Memory Efficient Attention},
	author = {Rezaei, Amin},
	howpublished = {\url{github.com/AminRezaei0x443/memory-efficient-attention}},
	year = {2021}
}

Also, for the paper:

@misc{rabe2021selfattention,
      title={Self-attention Does Not Need $O(n^2)$ Memory}, 
      author={Markus N. Rabe and Charles Staats},
      year={2021},
      eprint={2112.05682},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
You might also like...
Memory-efficient optimum einsum using opt_einsum planning and PyTorch kernels.

opt-einsum-torch There have been many implementations of Einstein's summation. numpy's numpy.einsum is the least efficient one as it only runs in sing

A human-readable PyTorch implementation of "Self-attention Does Not Need O(n^2) Memory"

memory_efficient_attention.pytorch A human-readable PyTorch implementation of "Self-attention Does Not Need O(n^2) Memory" (Rabe&Staats'21). def effic

 Attention for PyTorch with Linear Memory Footprint
Attention for PyTorch with Linear Memory Footprint

Attention for PyTorch with Linear Memory Footprint Unofficially implements https://arxiv.org/abs/2112.05682 to get Linear Memory Cost on Attention (+

Local Attention - Flax module for Jax

Local Attention - Flax Autoregressive Local Attention - Flax module for Jax Install $ pip install local-attention-flax Usage from jax import random fr

Reimplementation of the paper "Attention, Learn to Solve Routing Problems!" in jax/flax.

JAX + Attention Learn To Solve Routing Problems Reinplementation of the paper Attention, Learn to Solve Routing Problems! using Jax and Flax. Fully su

Official and maintained implementation of the paper
Official and maintained implementation of the paper "OSS-Net: Memory Efficient High Resolution Semantic Segmentation of 3D Medical Data" [BMVC 2021].

OSS-Net: Memory Efficient High Resolution Semantic Segmentation of 3D Medical Data Christoph Reich, Tim Prangemeier, Ă–zdemir Cetin & Heinz Koeppl | Pr

Lowest memory consumption and second shortest runtime in NTIRE 2022 challenge on Efficient Super-Resolution

FMEN Lowest memory consumption and second shortest runtime in NTIRE 2022 on Efficient Super-Resolution. Our paper: Fast and Memory-Efficient Network T

Implementation of
Implementation of "Efficient Regional Memory Network for Video Object Segmentation" (Xie et al., CVPR 2021).

RMNet This repository contains the source code for the paper Efficient Regional Memory Network for Video Object Segmentation. Cite this work @inprocee

Rethinking Space-Time Networks with Improved Memory Coverage for Efficient Video Object Segmentation
Rethinking Space-Time Networks with Improved Memory Coverage for Efficient Video Object Segmentation

STCN Rethinking Space-Time Networks with Improved Memory Coverage for Efficient Video Object Segmentation Ho Kei Cheng, Yu-Wing Tai, Chi-Keung Tang [a

Comments
  • feat: output_attentions

    feat: output_attentions

    I'm looking into hacking some of the models in the transformers library to use this library for attention, and I don't see a way to support output_attentions yet. This is a flag passed in transformers, where the attention weights are preserved and returned to the user, if it is set.

    I looked a little at implementing this in the torch backend, and I note the scan() function provides for only a single tensor return value. It seems to me that scan() function would be most clearly replaced by a for loop, but it could also be modified to handle tuples, or return_weights could be handled via accessing nonlocal data in some way instead of returning them through the chunk scanner. I'm also not sure how the output would best be passed to the user.

    Edit: Draft implementation 01/28 at https://github.com/AminRezaei0x443/memory-efficient-attention/compare/main...xloem:faba6371ac7faaa2040a2c26e15ae7ab87f94ce4 . I ended up extending the scan function for parity between implementations. Edit 2: Turns out it's the postsoftmax attention weights, not the presoftmax attention weights. I've updated this post and the draft implementation for this output: https://github.com/AminRezaei0x443/memory-efficient-attention/compare/main...xloem:return_weights

    opened by xloem 4
  • Provide a flag for the user to receive attention weights

    Provide a flag for the user to receive attention weights

    This is my draft code for #1. I saw this feature in the transformers library and wanted to implement it here.

    I'm curious what you think about this feature and implementation.

    The code is simply slightly instrumented so that the final attention weights can be returned to the user. Tests are augmented to test this use. In utils, the scan function is expanded to handle tuples.

    A change to dynamic_slice crept in from dev, to use slices rather than index_slice. I've retained this change because it looks like it would execute faster to me, but it can be removed.

    Rebased and squashed from 84724e1de4721ea0333d6bdbb91e8bce74fbeac .

    opened by xloem 2
  • Improve performance via batched-matmul and fused multiplies

    Improve performance via batched-matmul and fused multiplies

    Many thanks for providing this reference implementation.

    I tried integrating this into stable-diffusion / diffusers. A fix was required to make it work on Mac (PyTorch MPS backend):
    https://github.com/Birch-san/diffusers/pull/1/commits/04372140a25d7f53549175f1f196599c3e9bf3a5

    Knowing that computing attention via baddbmm()+bmm() can outperform einsum by 18%: I tried to rewrite the algorithm to use those.

    I compared the speed of my optimized version, against the implementation in this repository. this result is for "everything fits in one chunk" perf (i.e. chunk size = max token length). I was unable to compare chunked perf, because although I got chunking working in my version: I wasn't able to get it working in the version in this repository (got some unexpected-shape tensors returned).

    compared to the implementation in this repository:
    my optimized version achieves a 2.78x speedup in the time it took to generate a 512x512 image with stable-diffusion v2.1-base (i.e. 4096 vision tokens, 5 attention heads, batch size of 2 due to CFG).

    here's my optimized implementation:
    https://github.com/Birch-san/diffusers/pull/1

    batched matmuls require a 3D tensor, i.e. [batch * num_heads, tokens, channels_per_head].

    code that currently integrates agains this repository's [batch, q_length, num_heads, qk_depth_per_head] format can migrate those tensors to the [batch * num_heads, q_length, channels_per_head] format favoured by my implementation like so:

    query = query.transpose(1,2).flatten(end_dim=1)
    key = key.transpose(1,2).flatten(end_dim=1)
    value = value.transpose(1,2).flatten(end_dim=1)
    

    the result that's returned, remains in [batch * num_heads, q_length, qk_depth_per_head] format, and can be restored to [batch, q_length, num_heads, qk_depth_per_head] format like so:

    result.unflatten(0, (-1, attn.heads)).transpose(1,2)
    

    I think a further speedup is possible too: by working out when chunking is not needed: we can compute whether unchunked attention would fit into memory, and prefer unchunked attention as a fast-path where possible. this will be useful in a Unet, which runs attention at various resolutions.

    EDIT:
    I have now added fast-paths for:

    • skipping kv-chunking when kv_chunk_size >= k_tokens
      • this turns the algorithm into "attention slicing"
    • skipping q-chunking when q_chunk_size >= q_tokens
    • skipping all chunking when the kv_chunk_size >= k_tokens and q_chunk_size >= q_tokens
    • skipping all chunking when the [email protected] matmul requires fewer bytes than a user-provided threshold
    opened by Birch-san 1
Releases(0.1.3)
  • 0.1.2(Mar 7, 2022)

    What's Changed

    This update fixes torch device handling issues in code. GPU and other kinds of tensors can be used safely.

    • Update utils.py by @yhgon in https://github.com/AminRezaei0x443/memory-efficient-attention/pull/5
    • Update attention_torch.py by @yhgon in https://github.com/AminRezaei0x443/memory-efficient-attention/pull/6

    New Contributors

    • @yhgon made their first contribution in https://github.com/AminRezaei0x443/memory-efficient-attention/pull/5

    Full Changelog: https://github.com/AminRezaei0x443/memory-efficient-attention/compare/0.1.1.0...0.1.2

    Source code(tar.gz)
    Source code(zip)
  • 0.1.1.0(Feb 3, 2022)

    Added mask, bias calculation functions for custom and memory efficient chunks computation. So now sublinear memory computation mask, bias are possible.

    Full Changelog: https://github.com/AminRezaei0x443/memory-efficient-attention/compare/0.1.1...0.1.1.0

    Source code(tar.gz)
    Source code(zip)
Owner
Amin Rezaei
Computer Science BSc, Neural Networks Enthusiast
Amin Rezaei
Distance Encoding for GNN Design

Distance-encoding for GNN design This repository is the official PyTorch implementation of the DEGNN and DEAGNN framework reported in the paper: Dista

172 Nov 08, 2022
Style transfer, deep learning, feature transform

FastPhotoStyle License Copyright (C) 2018 NVIDIA Corporation. All rights reserved. Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons

NVIDIA Corporation 10.9k Jan 02, 2023
v objective diffusion inference code for JAX.

v-diffusion-jax v objective diffusion inference code for JAX, by Katherine Crowson (@RiversHaveWings) and Chainbreakers AI (@jd_pressman). The models

Katherine Crowson 186 Dec 21, 2022
Code for the paper A Theoretical Analysis of the Repetition Problem in Text Generation

A Theoretical Analysis of the Repetition Problem in Text Generation This repository share the code for the paper "A Theoretical Analysis of the Repeti

Zihao Fu 37 Nov 21, 2022
Repo for flood prediction using LSTMs and HAND

Abstract Every year, floods cause billions of dollars’ worth of damages to life, crops, and property. With a proper early flood warning system in plac

1 Oct 27, 2021
Neural Tangent Generalization Attacks (NTGA)

Neural Tangent Generalization Attacks (NTGA) ICML 2021 Video | Paper | Quickstart | Results | Unlearnable Datasets | Competitions | Citation Overview

Chia-Hung Yuan 34 Nov 25, 2022
SCNet: Learning Semantic Correspondence

SCNet Code Region matching code is contributed by Kai Han ([email protected]). Dense

Kai Han 34 Sep 06, 2022
Portfolio analytics for quants, written in Python

QuantStats: Portfolio analytics for quants QuantStats Python library that performs portfolio profiling, allowing quants and portfolio managers to unde

Ran Aroussi 2.7k Jan 08, 2023
This script scrapes and stores the availability of timeslots for Car Driving Test at all RTA Serivce NSW centres in the state.

This script scrapes and stores the availability of timeslots for Car Driving Test at all RTA Serivce NSW centres in the state. Dependencies Account wi

Balamurugan Soundararaj 21 Dec 14, 2022
Deep deconfounded recommender (Deep-Deconf) for paper "Deep causal reasoning for recommendations"

Deep Causal Reasoning for Recommender Systems The codes are associated with the following paper: Deep Causal Reasoning for Recommendations, Yaochen Zh

Yaochen Zhu 22 Oct 15, 2022
AoT is a system for automatically generating off-target test harness by using build information.

AoT: Auto off-Target Automatically generating off-target test harness by using build information. Brought to you by the Mobile Security Team at Samsun

Samsung 10 Oct 19, 2022
magiCARP: Contrastive Authoring+Reviewing Pretraining

magiCARP: Contrastive Authoring+Reviewing Pretraining Welcome to the magiCARP API, the test bed used by EleutherAI for performing text/text bi-encoder

EleutherAI 43 Dec 29, 2022
A custom DeepStack model for detecting 16 human actions.

DeepStack_ActionNET This repository provides a custom DeepStack model that has been trained and can be used for creating a new object detection API fo

MOSES OLAFENWA 16 Nov 11, 2022
code associated with ACL 2021 DExperts paper

DExperts Hi! This repository contains code for the paper DExperts: Decoding-Time Controlled Text Generation with Experts and Anti-Experts to appear at

Alisa Liu 68 Dec 15, 2022
Real Time Object Detection and Classification using Yolo Algorithm.

Real time Object detection & Classification using YOLO algorithm. Real Time Object Detection and Classification using Yolo Algorithm. What is Object D

Ketan Chawla 1 Apr 17, 2022
Pixel-level Crack Detection From Images Of Levee Systems : A Comparative Study

PIXEL-LEVEL CRACK DETECTION FROM IMAGES OF LEVEE SYSTEMS : A COMPARATIVE STUDY G

Manisha Panta 2 Jul 23, 2022
Learning Lightweight Low-Light Enhancement Network using Pseudo Well-Exposed Images

Learning Lightweight Low-Light Enhancement Network using Pseudo Well-Exposed Images This repository contains the implementation of the following paper

Seonggwan Ko 9 Jul 30, 2022
Reinforcement Learning Theory Book (rus)

Reinforcement Learning Theory Book (rus)

qbrick 206 Nov 27, 2022
AnimationKit: AI Upscaling & Interpolation using Real-ESRGAN+RIFE

ALPHA 2.5: Frostbite Revival (Released 12/23/21) Changelog: [ UI ] Chained design. All steps link to one another! Use the master override toggles to s

87 Nov 16, 2022
Tom-the-AI - A compound artificial intelligence software for Linux systems.

Tom the AI (version 0.82) WARNING: This software is not yet ready to use, I'm still setting up the GitHub repository. Should be ready in a few days. T

2 Apr 28, 2022