Implementation of Memory-Compressed Attention, from the paper "Generating Wikipedia By Summarizing Long Sequences"

Overview

Memory Compressed Attention

Implementation of the Self-Attention layer of the proposed Memory-Compressed Attention, in Pytorch. This repository offers both the causal and non-causal variant, and will take care of the padding if the sequence length is not divisible by the compression ratio.

The code also resolves an edge-case where the very first query have no keys to attend to in the auto-regressive scenario. The solution is to use null key/values, appended to the final compressed set, so that there is always at least 1 key for all queries to attend to.

Install

$ pip install memory_compressed_attention

Usage

import torch
from memory_compressed_attention import MemoryCompressedAttention

attn = MemoryCompressedAttention(
    dim = 512,
    heads = 8,                 # number of heads
    causal = False,            # auto-regressive or not
    compression_factor = 3,    # compression ratio
    dropout = 0.1              # dropout post-attention
)

x = torch.randn(1, 1024, 512)
mask = torch.ones(1, 1024).bool()

attn(x, input_mask = mask) # (1, 1024, 512)

Citations

@misc{liu2018generating,
    title={Generating Wikipedia by Summarizing Long Sequences},
    author={Peter J. Liu and Mohammad Saleh and Etienne Pot and Ben Goodrich and Ryan Sepassi and Lukasz Kaiser and Noam Shazeer},
    year={2018},
    eprint={1801.10198},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
You might also like...
Memory Efficient Attention (O(sqrt(n)) for Jax and PyTorch

Memory Efficient Attention This is unofficial implementation of Self-attention Does Not Need O(n^2) Memory for Jax and PyTorch. Implementation is almo

 Attention for PyTorch with Linear Memory Footprint
Attention for PyTorch with Linear Memory Footprint

Attention for PyTorch with Linear Memory Footprint Unofficially implements https://arxiv.org/abs/2112.05682 to get Linear Memory Cost on Attention (+

PyTorch code for our paper "Attention in Attention Network for Image Super-Resolution"

Under construction... Attention in Attention Network for Image Super-Resolution (A2N) This repository is an PyTorch implementation of the paper "Atten

Implementation of Transformer in Transformer, pixel level attention paired with patch level attention for image classification, in Pytorch
Implementation of Transformer in Transformer, pixel level attention paired with patch level attention for image classification, in Pytorch

Transformer in Transformer Implementation of Transformer in Transformer, pixel level attention paired with patch level attention for image c

Implementation of STAM (Space Time Attention Model), a pure and simple attention model that reaches SOTA for video classification
Implementation of STAM (Space Time Attention Model), a pure and simple attention model that reaches SOTA for video classification

STAM - Pytorch Implementation of STAM (Space Time Attention Model), yet another pure and simple SOTA attention model that bests all previous models in

Official Pytorch Implementation of Relational Self-Attention: What's Missing in Attention for Video Understanding
Official Pytorch Implementation of Relational Self-Attention: What's Missing in Attention for Video Understanding

Relational Self-Attention: What's Missing in Attention for Video Understanding This repository is the official implementation of "Relational Self-Atte

Official implementation of cosformer-attention in cosFormer: Rethinking Softmax in Attention

cosFormer Official implementation of cosformer-attention in cosFormer: Rethinking Softmax in Attention Update log 2022/2/28 Add core code License This

This is a pytorch implementation of the NeurIPS paper GAN Memory with No Forgetting.

GAN Memory for Lifelong learning This is a pytorch implementation of the NeurIPS paper GAN Memory with No Forgetting. Please consider citing our paper

Official and maintained implementation of the paper
Official and maintained implementation of the paper "OSS-Net: Memory Efficient High Resolution Semantic Segmentation of 3D Medical Data" [BMVC 2021].

OSS-Net: Memory Efficient High Resolution Semantic Segmentation of 3D Medical Data Christoph Reich, Tim Prangemeier, Özdemir Cetin & Heinz Koeppl | Pr

Comments
  • The order of masking and softmax operation

    The order of masking and softmax operation

    Hi,

    In memory_compressed_attention.py, I'm wondering if we need to do softmax operation after masking? Btw, if the entry in the mask should be float('-inf') instead of -float('-inf')? If I make something wrong, please correct me.

    image

    Thanks!

    opened by cfeng16 3
  • mask error in attention

    mask error in attention

    Very grateful for your pioneering work! I want to use it in Standard Transformer released in http://nlp.seas.harvard.edu/2018/04/03/attention.html. but it mat a mask error in training. more detail information shown as follow, the code i use: image class ConvCompress(nn.Module): def init(self, dim, ratio = 2, groups = 1): super(ConvCompress, self).init() self.conv = nn.Conv1d(dim, dim, ratio, stride = ratio, groups = groups) #self.linear = nn.Linear(dim, dim)

    def forward(self, mem):
        mem = mem.transpose(1, 2)
        compressed_mem = self.conv(mem)
        return compressed_mem.transpose(1, 2)
    

    class MemoryCompressedAttention(nn.Module): def init(self, h, d_model, compression_factor = 2, dropout = 0.1): super(MemoryCompressedAttention, self).init() assert (d_model % h) == 0, 'dimension must be divisible by number of heads' self.h = h self.d_model = d_model self.d_k = d_model // h

        self.compression_factor = compression_factor
        self.compress_fn = ConvCompress(d_model, compression_factor, groups = h)
    
        #self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
        self.wq = nn.Linear(d_model, d_model, bias = False)
        self.wk = nn.Linear(d_model, d_model, bias = False)
        self.wv = nn.Linear(d_model, d_model, bias = False)
    
        self.wo = nn.Linear(d_model, d_model)
    
        self.dropout = nn.Dropout(dropout)
    
        #self.null_k = nn.Parameter(torch.zeros(1, 1, d_model))
        #self.null_v = nn.Parameter(torch.zeros(1, 1, d_model))
    
    def forward(self, query, key, value, mask = None):
        
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)
        t = query.size(1)
        cf = self.compression_factor
    
        query = self.wq(query)
        key = self.wk(key)
        value = self.wv(value)
    
        # make sure keys and values sequence lengths
        # are divisible by the compression factor
        padding = cf - (t % cf)
        if padding != 0:
            key, value = map(lambda t: F.pad(t, (0, 0, padding, 0)), (key, value))
    
    
        # compress keys and values
        key, value = map(self.compress_fn, (key, value))
    
        # attach a null key and value, in the case that the first query has no keys to pay attention to
        null_k = nn.Parameter(torch.zeros(key.size(0), 1, self.d_model)).cuda()
        null_v = nn.Parameter(torch.zeros(value.size(0), 1, self.d_model)).cuda()
    
        key = torch.cat((null_k, key), dim=1)
        value = torch.cat((null_v, value), dim=1)
        
        # merge heads
        #query, key, value = map(lambda t: t.reshape(*t.shape[:2], h, -1).transpose(1, 2), (query, key, value))
        # 1) Do all the linear projections in batch from d_model => h x d_k
        query = query.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
        key = key.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
        value = value.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
    
      
        # 2) Apply attention on all the projected vectors in batch.
        x, self.attn = attention(query, key, value, mask=mask,
                                 dropout=self.dropout)
    
        # 3) "Concat" using a view and apply a final linear.   # split heads and combine
        x = x.contiguous().view(nbatches, -1, self.d_model)
        out = self.wo(x)
    
        return out
    

    The error was show that image

    I want to know how to fix it, and how to do mask for N*M matrix??

    opened by HN123-123 0
Releases(0.0.5)
Owner
Phil Wang
Working with Attention. It's all we need
Phil Wang
TorchGRL is the source code for our paper Graph Convolution-Based Deep Reinforcement Learning for Multi-Agent Decision-Making in Mixed Traffic Environments for IV 2022.

TorchGRL TorchGRL is the source code for our paper Graph Convolution-Based Deep Reinforcement Learning for Multi-Agent Decision-Making in Mixed Traffi

XXQQ 42 Dec 09, 2022
Code release for "Detecting Twenty-thousand Classes using Image-level Supervision".

Detecting Twenty-thousand Classes using Image-level Supervision Detic: A Detector with image classes that can use image-level labels to easily train d

Meta Research 1.3k Jan 04, 2023
A large-scale benchmark for co-optimizing the design and control of soft robots, as seen in NeurIPS 2021.

Evolution Gym A large-scale benchmark for co-optimizing the design and control of soft robots. As seen in Evolution Gym: A Large-Scale Benchmark for E

121 Dec 14, 2022
A Python wrapper for Google Tesseract

Python Tesseract Python-tesseract is an optical character recognition (OCR) tool for python. That is, it will recognize and "read" the text embedded i

Matthias A Lee 4.6k Jan 05, 2023
This is an official implementation of CvT: Introducing Convolutions to Vision Transformers.

Introduction This is an official implementation of CvT: Introducing Convolutions to Vision Transformers. We present a new architecture, named Convolut

Microsoft 408 Dec 30, 2022
An example to implement a new backbone with OpenMMLab framework.

Backbone example on OpenMMLab framework English | 简体中文 Introduction This is an template repo about how to use OpenMMLab framework to develop a new bac

Ma Zerun 22 Dec 29, 2022
Testing the Facial Emotion Recognition (FER) algorithm on animations

PegHeads-Tutorial-3 Testing the Facial Emotion Recognition (FER) algorithm on animations

PegHeads Inc 2 Jan 03, 2022
A PyTorch implementation of Implicit Q-Learning

IQL-PyTorch This repository houses a minimal PyTorch implementation of Implicit Q-Learning (IQL), an offline reinforcement learning algorithm, along w

Garrett Thomas 30 Dec 12, 2022
This repository contains answers of the Shopify Summer 2022 Data Science Intern Challenge.

Data-Science-Intern-Challenge This repository contains answers of the Shopify Summer 2022 Data Science Intern Challenge. Summer 2022 Data Science Inte

1 Jan 11, 2022
Apache Spark - A unified analytics engine for large-scale data processing

Apache Spark Spark is a unified analytics engine for large-scale data processing. It provides high-level APIs in Scala, Java, Python, and R, and an op

The Apache Software Foundation 34.7k Jan 04, 2023
A lightweight library to compare different PyTorch implementations of the same network architecture.

TorchBug is a lightweight library designed to compare two PyTorch implementations of the same network architecture. It allows you to count, and compar

Arjun Krishnakumar 5 Jan 02, 2023
Several simple examples for popular neural network toolkits calling custom CUDA operators.

Neural Network CUDA Example Several simple examples for neural network toolkits (PyTorch, TensorFlow, etc.) calling custom CUDA operators. We provide

WeiYang 798 Jan 01, 2023
Akshat Surolia 2 May 11, 2022
SmoothGrad implementation in PyTorch

SmoothGrad implementation in PyTorch PyTorch implementation of SmoothGrad: removing noise by adding noise. Vanilla Gradients SmoothGrad Guided backpro

SSKH 143 Jan 05, 2023
SporeAgent: Reinforced Scene-level Plausibility for Object Pose Refinement

SporeAgent: Reinforced Scene-level Plausibility for Object Pose Refinement This repository implements the approach described in SporeAgent: Reinforced

Dominik Bauer 5 Jan 02, 2023
Interacting Two-Hand 3D Pose and Shape Reconstruction from Single Color Image (ICCV 2021)

Interacting Two-Hand 3D Pose and Shape Reconstruction from Single Color Image Interacting Two-Hand 3D Pose and Shape Reconstruction from Single Color

75 Dec 02, 2022
Download from Onlyfans.com.

OnlySave: Onlyfans downloader Getting Started: Download the setup executable from the latest release. Install and run. Only works on Windows currently

4 May 30, 2022
Official repository for the ICLR 2021 paper Evaluating the Disentanglement of Deep Generative Models with Manifold Topology

Official repository for the ICLR 2021 paper Evaluating the Disentanglement of Deep Generative Models with Manifold Topology Sharon Zhou, Eric Zelikman

Stanford Machine Learning Group 34 Nov 16, 2022
Retina blood vessel segmentation with a convolutional neural network

Retina blood vessel segmentation with a convolution neural network (U-net) This repository contains the implementation of a convolutional neural netwo

Orobix 1.2k Jan 06, 2023
Like a cowsay but without cows!

Foxsay This is a simple program that generates pictures of a cute fox with a message. It is like a cowsay but without cows! Fox girls are better! Usag

Anastasia Kim 28 Feb 20, 2022