General-purpose implementations of self-attention mechanisms, with a focus on computer vision modules. This repository is a work in progress.

Overview

Self-attention building blocks for computer vision applications in PyTorch

Self-attention mechanisms for computer vision, implemented in PyTorch with einsum and einops.

Install it via pip

It is a good idea to install PyTorch in your environment beforehand, in case you don't have a GPU.

pip install self-attention-cv

Related articles

More articles are on the way.

Code Examples

Multi-head attention

import torch
from self_attention_cv import MultiHeadSelfAttention

model = MultiHeadSelfAttention(dim=64)
x = torch.rand(16, 10, 64)  # [batch, tokens, dim]
mask = torch.zeros(10, 10)  # tokens X tokens
mask[5:8, 5:8] = 1
y = model(x, mask)
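
For readers who want to see what such a module computes, here is a minimal sketch of masked scaled dot-product attention written with `torch.einsum`. It is not the library's code: the function name `masked_self_attention` is hypothetical, the query/key/value projections are omitted, and the mask convention (a boolean mask where `True` excludes a query/key pair) is an assumption of this sketch.

```python
import torch


def masked_self_attention(q, k, v, mask=None):
    """Scaled dot-product attention over [batch, heads, tokens, dim_head] tensors.

    mask is an optional boolean [tokens, tokens] tensor; True marks query/key
    pairs to exclude (an assumed convention for this sketch).
    """
    scale = q.shape[-1] ** -0.5
    # Pairwise query-key similarities: [batch, heads, tokens, tokens]
    scores = torch.einsum("bhid,bhjd->bhij", q, k) * scale
    if mask is not None:
        scores = scores.masked_fill(mask, float("-inf"))
    attention = torch.softmax(scores, dim=-1)
    # Weighted sum of the values: [batch, heads, tokens, dim_head]
    return torch.einsum("bhij,bhjd->bhid", attention, v), attention


q = k = v = torch.rand(16, 8, 10, 8)  # 8 heads of size 8, i.e. dim=64
mask = torch.zeros(10, 10, dtype=torch.bool)
mask[5:8, 5:8] = True
out, attention = masked_self_attention(q, k, v, mask)
print(out.shape)  # torch.Size([16, 8, 10, 8])
```

Masked pairs receive a score of minus infinity, so their softmax weight is exactly zero while every row of the attention matrix still sums to one.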

Axial attention

import torch
from self_attention_cv import AxialAttentionBlock
model = AxialAttentionBlock(in_channels=256, dim=64, heads=8)
x = torch.rand(1, 256, 64, 64)  # [batch, channels, dim, dim]
y = model(x)

Vanilla Transformer Encoder

import torch
from self_attention_cv import TransformerEncoder
model = TransformerEncoder(dim=64, blocks=6, heads=8)
x = torch.rand(16, 10, 64)  # [batch, tokens, dim]
mask = torch.zeros(10, 10)  # tokens X tokens
mask[5:8, 5:8] = 1
y = model(x, mask)

Vision Transformer with/without ResNet50 backbone for image classification

import torch
from self_attention_cv import ViT, ResNet50ViT

model1 = ResNet50ViT(img_dim=128, pretrained_resnet=False,
                     blocks=6, num_classes=10,
                     dim_linear_block=256, dim=256)
# or
model2 = ViT(img_dim=256, in_channels=3, patch_dim=16, num_classes=10, dim=512)
x = torch.rand(2, 3, 256, 256)
y = model2(x)  # [2, 10]
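
To make the relation between `img_dim`, `patch_dim` and the token sequence concrete, the arithmetic can be sketched as below. This assumes the usual ViT scheme of non-overlapping square patches; the helper `vit_token_layout` is hypothetical, and any extra classification token a model may prepend is not counted.

```python
def vit_token_layout(img_dim, patch_dim, in_channels):
    """Return (number of patch tokens, flattened size of one patch)."""
    if img_dim % patch_dim != 0:
        raise ValueError("img_dim must be divisible by patch_dim")
    patches_per_side = img_dim // patch_dim
    tokens = patches_per_side ** 2
    token_dim = in_channels * patch_dim ** 2
    return tokens, token_dim


tokens, token_dim = vit_token_layout(img_dim=256, patch_dim=16, in_channels=3)
print(tokens, token_dim)  # 256 768
```

Under these assumptions, a 256x256 RGB image with 16x16 patches yields 256 patch tokens, each flattened from 768 values before being projected to the model dimension.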

A re-implementation of Unet with the Vision Transformer encoder

import torch
from self_attention_cv.transunet import TransUnet
a = torch.rand(2, 3, 128, 128)
model = TransUnet(in_channels=3, img_dim=128, vit_blocks=8,
                  vit_dim_linear_mhsa_block=512, classes=5)
y = model(a)  # [2, 5, 128, 128]

Bottleneck Attention block

import torch
from self_attention_cv.bottleneck_transformer import BottleneckBlock
inp = torch.rand(1, 512, 32, 32)
bottleneck_block = BottleneckBlock(in_channels=512, fmap_size=(32, 32), heads=4, out_channels=1024, pooling=True)
y = bottleneck_block(inp)

Position embeddings are also available

1D Positional Embeddings

import torch
from self_attention_cv.pos_embeddings import AbsPosEmb1D, RelPosEmb1D

model = AbsPosEmb1D(tokens=20, dim_head=64)
# batch heads tokens dim_head
q = torch.rand(2, 3, 20, 64)
y1 = model(q)

model = RelPosEmb1D(tokens=20, dim_head=64, heads=3)
q = torch.rand(2, 3, 20, 64)
y2 = model(q)
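
The modules above take the query tensor as input. One common design, sketched below, scores each query against a learned embedding per position and produces position logits that can be added to the content attention scores. The class `AbsPosScores1D` is a hypothetical illustration written under that assumption and is not taken from the library's source.

```python
import torch
from torch import nn


class AbsPosScores1D(nn.Module):
    """Illustrative absolute position scores (hypothetical, not the library class)."""

    def __init__(self, tokens, dim_head):
        super().__init__()
        scale = dim_head ** -0.5
        # One learnable embedding vector per token position
        self.emb = nn.Parameter(torch.randn(tokens, dim_head) * scale)

    def forward(self, q):
        # q: [batch, heads, tokens, dim_head] -> [batch, heads, tokens, tokens]
        return torch.einsum("bhid,jd->bhij", q, self.emb)


q = torch.rand(2, 3, 20, 64)
pos_scores = AbsPosScores1D(tokens=20, dim_head=64)(q)
print(pos_scores.shape)  # torch.Size([2, 3, 20, 20])
```

In this sketch the output has the same shape as the attention score matrix, which is why the query is required as input.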

2D Positional Embeddings

import torch
from self_attention_cv.pos_embeddings import RelPosEmb2D
dim = 32  # spatial dim of the feat map
model = RelPosEmb2D(
    feat_map_size=(dim, dim),
    dim_head=128)

q = torch.rand(2, 4, dim*dim, 128)
y = model(q)

References

  1. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., ... & Polosukhin, I. (2017). Attention is all you need. arXiv preprint arXiv:1706.03762.
  2. Wang, H., Zhu, Y., Green, B., Adam, H., Yuille, A., & Chen, L. C. (2020, August). Axial-deeplab: Stand-alone axial-attention for panoptic segmentation. In European Conference on Computer Vision (pp. 108-126). Springer, Cham.
  3. Srinivas, A., Lin, T. Y., Parmar, N., Shlens, J., Abbeel, P., & Vaswani, A. (2021). Bottleneck Transformers for Visual Recognition. arXiv preprint arXiv:2101.11605.
  4. Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., ... & Houlsby, N. (2020). An image is worth 16x16 words: Transformers for image recognition at scale. arXiv preprint arXiv:2010.11929.
Comments
  • Thank you very much for the code. But when I run test_TransUnet.py, it starts reporting errors. Why is that? Could you please help me solve it? Thank you

    Thank you very much for the code. But when I run test_TransUnet.py, it starts reporting errors. Why is that?

    Traceback (most recent call last):
      File "self-attention-cv/tests/test_TransUnet.py", line 14, in <module>
        test_TransUnet()
      File "/self-attention-cv/tests/test_TransUnet.py", line 11, in test_TransUnet
        y = model(a)
      File "C:\Users\dell.conda\envs\myenv\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
        result = self.forward(*input, **kwargs)
      File "self-attention-cv\self_attention_cv\transunet\trans_unet.py", line 88, in forward
        y = self.project_patches_back(y)
      File "C:\Users\dell.conda\envs\myenv\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
        result = self.forward(*input, **kwargs)
      File "C:\Users\dell.conda\envs\myenv\lib\site-packages\torch\nn\modules\linear.py", line 93, in forward
        return F.linear(input, self.weight, self.bias)
      File "C:\Users\dell.conda\envs\myenv\lib\site-packages\torch\nn\functional.py", line 1692, in linear
        output = input.matmul(weight.t())
    RuntimeError: mat1 dim 1 must match mat2 dim 0

    Process finished with exit code 1

    Could you please help me solve it? Thank you.

    opened by yezhengjie 7
  • TransUNet - Why is the patch_dim set to 1?

    Hi,

    Can you please explain why is the patch_dim set to 1 in TransUNet class? Thank you in advance!

    https://github.com/The-AI-Summer/self-attention-cv/blob/8280009366b633921342db6cab08da17b46fdf1c/self_attention_cv/transunet/trans_unet.py#L54

    opened by dsitnik 7
  • Question: Sliding Window Module for Transformer3dSeg Object

    I was wondering whether or not you've implemented an example using the network in a 3d medical segmentation task and/or use case? If this network only exports the center slice of a patch then we would need a wrapper function to iterate through all patches in an image to get the final prediction for the entire volume. From the original paper, I assume they choose 10 patches at random from an image during training, but it's not too clear how they pieced everything together during testing.

    Your thoughts on this would be greatly appreciated!

    See: https://github.com/The-AI-Summer/self-attention-cv/blob/33ddf020d2d9fb9c4a4a3b9938383dc9b7405d8c/self_attention_cv/Transformer3Dsegmentation/tranf3Dseg.py#L10

    opened by jmarsil 5
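
One possible shape for the wrapper discussed in this thread is a sliding window over the volume. The sketch below only enumerates patch corner coordinates so that every voxel is covered; the function names, patch size and stride are hypothetical, and this is not necessarily how the original paper assembles predictions at test time.

```python
from itertools import product


def window_starts(size, patch, stride):
    """Start offsets along one axis so that windows of length patch cover [0, size)."""
    if patch > size:
        raise ValueError("patch must not exceed the axis size")
    starts = list(range(0, size - patch + 1, stride))
    if starts[-1] != size - patch:
        starts.append(size - patch)  # last window sits flush with the border
    return starts


def sliding_windows_3d(volume_shape, patch_shape, stride):
    """All (z, y, x) patch corners needed to cover a 3D volume."""
    axes = [window_starts(s, p, st)
            for s, p, st in zip(volume_shape, patch_shape, stride)]
    return list(product(*axes))


corners = sliding_windows_3d((64, 64, 32), (24, 24, 24), (20, 20, 8))
print(len(corners))  # 18
```

Each corner identifies one patch to feed to the network; how overlapping predictions are merged (for example by averaging) is left open here.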
  • ResNet + Pyramid Vision Transformer Version 2

    Thank you for your work and the clear explanation. As you know, ViT doesn't work well on small datasets, so I am implementing ResNet34 with Pyramid Vision Transformer Version 2 to improve on it. The architectures of ViT and PVT V2 are completely different. Could you please provide some help with implementing it?

    opened by khawar-islam 3
  • Request for Including UNETR

    Thanks for the great work! I noticed a nice implementation of this paper (https://arxiv.org/abs/2103.10504) here:

    https://github.com/tamasino52/UNETR/blob/main/unetr.py

    It would be great if this can also be included in your repo, since it comes with lots of other great features. So we can explore more.

    Thanks ~

    opened by Siyuan89 3
  • ImageNet Pretrained TimesFormer

    I see you have recently added the TimesFormer model to this repository. In the paper, they initialize their model weights from ImageNet pretrained weights of ViT. Does your implementation offer this too? Thanks!

    opened by RaivoKoot 3
  • Do the encoder modules incorporate positional encoding?

    If I use, say, the LinformerEncoder, do I have to add the position encoding myself, or is that already done? From the source files it doesn't seem to be there, but I'm not sure how to include the position encodings, as they seem to need the query, which isn't available when passing data directly to the LinformerEncoder. I may well be missing something, so any help would be great. Perhaps an example using positional encoding would be good.

    opened by jfkback 3
  • use AxialAttention on gpu

    I am trying to use AxialAttention on a GPU, but I get an error. Can you give me some tips on using AxialAttention on a GPU? Thanks! Error: RuntimeError: expected self and mask to be on the same device, but got mask on cpu and self on cuda:0

    opened by Iverson-Al 2
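
A message of this kind generally means that a mask tensor was created on the CPU and then combined with a CUDA tensor. A common general remedy in PyTorch, sketched below, is to register the mask as a module buffer so that it moves together with the module. `MaskedScores` is a hypothetical module used only for illustration; whether this matches the cause inside AxialAttention is an assumption.

```python
import torch
from torch import nn


class MaskedScores(nn.Module):
    """Hypothetical module whose mask follows the module across devices."""

    def __init__(self, tokens):
        super().__init__()
        mask = torch.ones(tokens, tokens).triu(diagonal=1).bool()
        # Buffers are moved by .to(device) / .cuda(), unlike plain attributes
        self.register_buffer("mask", mask)

    def forward(self, scores):
        return scores.masked_fill(self.mask, float("-inf"))


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
module = MaskedScores(tokens=4).to(device)
scores = torch.zeros(2, 4, 4, device=device)
out = module(scores)
print(module.mask.device.type == out.device.type)  # True
```

For a mask created outside the module, calling `mask.to(x.device)` before the forward pass achieves the same effect.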
  • Axial attention

    What is the meaning of qkv_channels? https://github.com/The-AI-Summer/self-attention-cv/blob/5246e550ecb674f60df76a6c1011fde30ded7f44/self_attention_cv/axial_attention_deeplab/axial_attention.py#L32

    opened by Jayden9912 1
  • Convolution-Free Medical Image Segmentation using Transformers

    Thank you very much for your contribution. As a novice, I have a doubt. In tranf3dseg, the output of the model is the prediction segmentation of the center patch, so how can I get the segmentation of the whole input image? I am looking forward to any reply.

    opened by WinsaW 1
  • Regression with attention

    Hello!

    thanks for sharing this nice repo :)

    I'm trying to use ViT to do regression on images. I'd like to predict 6 floats per image.

    My understanding is that I'd need to simply define the network as

    vit = ViT(img_dim=128,
              in_channels=3,
              patch_dim=16,
              num_classes=6,
              dim=512)


    and during training call

    vit(x)
    

    and compute the loss as MSE instead of CE.

    The network actually runs but it doesn't seem to converge. Is there something obvious I am missing?

    many thanks!

    opened by alemelis 1
  • Segmentation for full image

    Hi,

    Thank you for your effort and time in implementing this. I have a quick question, I want to get segmentation for full image not just for the middle token, would it be correct to change self.tokens to self.p here:

    https://github.com/The-AI-Summer/self-attention-cv/blob/5246e550ecb674f60df76a6c1011fde30ded7f44/self_attention_cv/Transformer3Dsegmentation/tranf3Dseg.py#L66

    and change this:

    https://github.com/The-AI-Summer/self-attention-cv/blob/5246e550ecb674f60df76a6c1011fde30ded7f44/self_attention_cv/Transformer3Dsegmentation/tranf3Dseg.py#L94

    to

    y = self.mlp_seg_head(y)

    opened by aqibsaeed 0
Releases(1.2.3)
Owner
AI Summer
Learn Deep Learning and Artificial Intelligence