Implementation of Nyström Self-attention, from the paper Nyströmformer

Overview

Nyström Attention

Implementation of Nyström Self-attention, from the paper Nyströmformer.

Yannic Kilcher video

Install

$ pip install nystrom-attention

Usage

import torch
from nystrom_attention import NystromAttention

attn = NystromAttention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    num_landmarks = 256,    # number of landmarks
    pinv_iterations = 6,    # number of moore-penrose iterations for approximating pinverse. 6 was recommended by the paper
    residual = True         # whether to do an extra residual with the value or not. supposedly faster convergence if turned on
)

x = torch.randn(1, 16384, 512)
mask = torch.ones(1, 16384).bool()

attn(x, mask = mask) # (1, 16384, 512)

Nyströmformer, layers of Nyström attention

import torch
from nystrom_attention import Nystromformer

model = Nystromformer(
    dim = 512,
    dim_head = 64,
    heads = 8,
    depth = 6,
    num_landmarks = 256,
    pinv_iterations = 6
)

x = torch.randn(1, 16384, 512)
mask = torch.ones(1, 16384).bool()

model(x, mask = mask) # (1, 16384, 512)

You can also import it as Nyströmer if you wish

from nystrom_attention import Nystromer

Citations

@misc{xiong2021nystromformer,
    title   = {Nyströmformer: A Nyström-Based Algorithm for Approximating Self-Attention},
    author  = {Yunyang Xiong and Zhanpeng Zeng and Rudrasis Chakraborty and Mingxing Tan and Glenn Fung and Yin Li and Vikas Singh},
    year    = {2021},
    eprint  = {2102.03902},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
Comments
  • Clarification on masking

    Clarification on masking

    Given the dimensionality of the mask argument, (N, T), I'm assuming this is a boolean mask for masking out padding tokens. I created the following function to generate such a mask given an input tensor:

    def _create_pad_mask(self, x: torch.LongTensor) -> torch.BoolTensor:
        mask = torch.ones_like(x).to(torch.bool)
        mask[x==0] = False
        return mask
    

    where 0 is the padding token, setting positions to False so not to attend to them.

    However, I am unsure how to apply a causal mask to the attention layers so to prevent my decoder from accessing future elements. I couldn't see an example of this in the full Nystromformer module. How can I achieve this?

    For context, I am trying to apply the causal mask generated by the following function:

    def _create_causal_mask(self, x: torch.LongTensor) -> torch.FloatTensor:
        size = x.shape[1]
        mask = (torch.triu(torch.ones(size, size)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill_(mask == 0, float('-inf')).masked_fill_(mask==1, 0.0)
        return mask
    

    One way I can think of is to set return_attn to True, apply the mask on the returned attention weights then matmul with the value tensor. But this has a few issues:

    • Having to return v
    • Computing the full attention matrix (I think), defeating the entire point of linear attention
    • Needlessly calculating out only to discard it.

    Is this just a limitation of Nystrom attention? Or am I overlooking something obvious?

    Thanks

    opened by vvvm23 3
  • Possible bug with padding

    Possible bug with padding

    Hey there,

    I was going through the code and I noticed the following, which I found curious.

    In Line 75, you pad the input tensor to a multiple of num_landmarks from the front:

    x = F.pad(x, (0, 0, padding, 0), value = 0)
    

    In Line 144 you trim the extra padding elements you inserted in the output tensor from the end.

    out = out[:, :n]
    

    Am I not getting something, or should we be removing the front elements of out?

    out = out[:, out.size(1) - n:]
    
    opened by georgepar 2
  • Nystrom for Image processing

    Nystrom for Image processing

    thank you for sharing the wondeful code. I am working on image processing and wanted to try your code for the same. I have 2 doubts:

    1. How to select residual_conv_kernel? I could not find any details for the same. also, it is enabled by a flag. When should we enable it and when to disable it?
    2. Is there any guideline for deciding num_landmarks for image processing task?

    Thanks

    opened by paragon1234 1
  • Error when mask is of the same size as that of the input X

    Error when mask is of the same size as that of the input X

    Hi,

    First of all, thank you for putting such an easy to use implementation on GitHub. I'm trying to incorporate the nystrom attention into a legacy codebase, it previously used to provide the input X and the mask (off the same dimensions as X) to a Multi headed Attention Layer.

    When I'm trying to integrate nystrom attention with it, it runs alright without the mask. But, when I pass the mask alongside it, it throws einops rearrange error.

    Sorry, if this is a very basic question, but how would you recommend I deal with handling 3D mask (same dimensions as the size of input) in the codebase.

    Best, VB

    opened by Vaibhavs10 1
  • ViewBackward inplace deprecation warning

    ViewBackward inplace deprecation warning

    Hello again,

    The following code results in a UserWarning in PyTorch 1.8.1.

    In [1]: from nystrom_attention.nystrom_attention import NystromAttention
    
    In [2]: import torch
    
    In [3]: attn = NystromAttention(256)
    
    In [4]: x = torch.randn(1, 8192, 256)
    
    In [5]: attn(x)
    /home/alex/.tmp/nystrom-attention/nystrom_attention/nystrom_attention.py:91: UserWarning: Output 0 of ViewBackward is a view and is being modified inplace. This view is an output of a function that returns multiple views. Inplace operators on such views are being deprecated and will be forbidden starting from version 1.8. Consider using `unsafe_` version of the function that produced this view or don't modify this view inplace. (Triggered internally at  ../torch/csrc/autograd/variable.cpp:547.)
      q *= self.scale
    Out[5]:
    tensor([[[-0.0449, -0.1726,  0.1409,  ...,  0.0127,  0.2287, -0.2437],
             [-0.1132,  0.3229, -0.1279,  ...,  0.0084, -0.3307, -0.2351],
             [ 0.0361,  0.1013,  0.0828,  ...,  0.1045, -0.1627,  0.0736],
             ...,
             [ 0.0018,  0.1385, -0.1716,  ..., -0.0366, -0.0682,  0.0241],
             [ 0.1497,  0.0149, -0.0020,  ..., -0.0352, -0.1126,  0.0193],
             [ 0.1341,  0.0077,  0.1627,  ..., -0.0363,  0.1057, -0.2071]]],
           grad_fn=<SliceBackward>)
    

    Not a huge issue, but worth mentioning

    opened by vvvm23 1
  • Relative position encoding

    Relative position encoding

    Similar to the question raised for the performer architecture , is it possible to implement a relative position encoding given the methodology in which attention is calculated?

    opened by jdcla 1
  • How can we implement

    How can we implement "batch_first" in Nystrom attention?

    Hi,

    Thanks a lot for implementing the nystromformer attention algorithm! Very nice job!

    I am wondering whether it is feasible to add the "batch_first" option in the nystrom attention algorithm? This allow the algorithm to be integrated in the existing pytorch transformer encoder architecture.

    opened by mark0935git 0
  • x-transformers

    x-transformers

    Hi @lucidrains - just wondering if we can plug in Nystrom Attention with x-transformers?

    I've been plugging in Vision Transformers with X-transformers but am wondering if its possible to have a Nystrom transformer with x-transformer improvements to plug into a ViT?

    opened by robbohua 0
Owner
Phil Wang
Working with Attention. It's all we need.
Phil Wang
[CVPR 2021] Rethinking Semantic Segmentation from a Sequence-to-Sequence Perspective with Transformers

[CVPR 2021] Rethinking Semantic Segmentation from a Sequence-to-Sequence Perspective with Transformers

Fudan Zhang Vision Group 897 Jan 05, 2023
Pyeventbus: a publish/subscribe event bus

pyeventbus pyeventbus is a publish/subscribe event bus for Python 2.7. simplifies the communication between python classes decouples event senders and

15 Apr 21, 2022
Hypercomplex Neural Networks with PyTorch

HyperNets Hypercomplex Neural Networks with PyTorch: this repository would be a container for hypercomplex neural network modules to facilitate resear

Eleonora Grassucci 21 Dec 27, 2022
One-Shot Neural Ensemble Architecture Search by Diversity-Guided Search Space Shrinking

One-Shot Neural Ensemble Architecture Search by Diversity-Guided Search Space Shrinking This is an official implementation for NEAS presented in CVPR

Multimedia Research 19 Sep 08, 2022
Baselines for TrajNet++

TrajNet++ : The Trajectory Forecasting Framework PyTorch implementation of Human Trajectory Forecasting in Crowds: A Deep Learning Perspective TrajNet

VITA lab at EPFL 183 Jan 05, 2023
Reference code for the paper CAMS: Color-Aware Multi-Style Transfer.

CAMS: Color-Aware Multi-Style Transfer Mahmoud Afifi1, Abdullah Abuolaim*1, Mostafa Hussien*2, Marcus A. Brubaker1, Michael S. Brown1 1York University

Mahmoud Afifi 36 Dec 04, 2022
Bayesian Optimization Library for Medical Image Segmentation.

bayesmedaug: Bayesian Optimization Library for Medical Image Segmentation. bayesmedaug optimizes your data augmentation hyperparameters for medical im

Şafak Bilici 7 Feb 10, 2022
Pytorch Implementation of LNSNet for Superpixel Segmentation

LNSNet Overview Official implementation of Learning the Superpixel in a Non-iterative and Lifelong Manner (CVPR'21) Learning Strategy The proposed LNS

42 Oct 11, 2022
Example how to deploy deep learning model with aiohttp.

aiohttp-demos Demos for aiohttp project. Contents Imagetagger Deep Learning Image Classifier URL shortener Toxic Comments Classifier Moderator Slack B

aio-libs 661 Jan 04, 2023
Face Mask Detection system based on computer vision and deep learning using OpenCV and Tensorflow/Keras

Face Mask Detection Face Mask Detection System built with OpenCV, Keras/TensorFlow using Deep Learning and Computer Vision concepts in order to detect

Chandrika Deb 1.4k Jan 03, 2023
A new benchmark for Icon Question Answering (IconQA) and a large-scale icon dataset Icon645.

IconQA About IconQA is a new diverse abstract visual question answering dataset that highlights the importance of abstract diagram understanding and c

Pan Lu 24 Dec 30, 2022
Unofficial pytorch implementation for Self-critical Sequence Training for Image Captioning. and others.

An Image Captioning codebase This is a codebase for image captioning research. It supports: Self critical training from Self-critical Sequence Trainin

Ruotian(RT) Luo 906 Jan 03, 2023
Tool for live presentations using manim

manim-presentation Tool for live presentations using manim Install pip install manim-presentation opencv-python Usage Use the class Slide as your sce

Federico Galatolo 146 Jan 06, 2023
Script utilizando OpenCV e modelo Machine Learning para detectar o uso de máscaras.

Reconhecendo máscaras Este repositório contém um script em Python3 que reconhece se um rosto está ou não portando uma máscara! O código utiliza da bib

Maria Eduarda de Azevedo Silva 168 Oct 20, 2022
Real-time VIBE: Frame by Frame Inference of VIBE (Video Inference for Human Body Pose and Shape Estimation)

Real-time VIBE Inference VIBE frame-by-frame. Overview This is a frame-by-frame inference fork of VIBE at [https://github.com/mkocabas/VIBE]. Usage: i

23 Jul 02, 2022
PyTorch implementation of paper "StarEnhancer: Learning Real-Time and Style-Aware Image Enhancement" (ICCV 2021 Oral)

StarEnhancer StarEnhancer: Learning Real-Time and Style-Aware Image Enhancement (ICCV 2021 Oral) Abstract: Image enhancement is a subjective process w

IDKiro 133 Dec 28, 2022
Python package provinding tools for artistic interactive applications using AI

Documentation redrawing Python package provinding tools for artistic interactive applications using AI Created by ReDrawing Campinas team for the Open

ReDrawing Campinas 1 Sep 30, 2021
Segmentation in Style: Unsupervised Semantic Image Segmentation with Stylegan and CLIP

Segmentation in Style: Unsupervised Semantic Image Segmentation with Stylegan and CLIP Abstract: We introduce a method that allows to automatically se

Daniil Pakhomov 134 Dec 19, 2022
This repo is about implementing different approaches of pose estimation and also is a sub-task of the smart hospital bed project :smile:

Pose-Estimation This repo is a sub-task of the smart hospital bed project which is about implementing the task of pose estimation 😄 Many thanks to th

Max 11 Oct 17, 2022
Using LSTM to detect spoofing attacks in an Air-Ground network

Using LSTM to detect spoofing attacks in an Air-Ground network Specifications IDE: Spider Packages: Tensorflow 2.1.0 Keras NumPy Scikit-learn Matplotl

Tiep M. H. 1 Nov 20, 2021