Implementation of general-purpose self-attention mechanisms, focused on computer vision modules. This is an ongoing repository.

Overview

Self-attention building blocks for computer vision applications in PyTorch

Implementation of self-attention mechanisms for computer vision in PyTorch, written with einsum and einops.

Install it via pip

If you don't have a GPU, it is recommended to install PyTorch in your environment first.

pip install self-attention-cv

Related articles

More articles are on the way.

Code Examples

Multi-head attention

import torch
from self_attention_cv import MultiHeadSelfAttention

model = MultiHeadSelfAttention(dim=64)
x = torch.rand(16, 10, 64)  # [batch, tokens, dim]
mask = torch.zeros(10, 10, dtype=torch.bool)  # [tokens, tokens]
mask[5:8, 5:8] = True
y = model(x, mask)
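To make the role of the tokens x tokens mask concrete, here is a minimal single-head sketch of masked scaled dot-product attention on plain Python lists. It is an illustration, not the library's code, and it assumes the masked_fill-style convention in which a True entry at [i][j] means query i may not attend to key j.

```python
import math


def masked_attention(q, k, v, mask=None):
    """Single-head scaled dot-product attention on nested lists.

    q, k, v: [tokens][dim] lists of token vectors.
    mask: optional [tokens][tokens] booleans; True means "query i may not attend to key j"
    (assumed masked_fill-style convention). Every row needs at least one False entry.
    """
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, qi in enumerate(q):
        # similarity of query i with every key, scaled by 1/sqrt(dim)
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        if mask is not None:
            # masked positions get -inf, so their softmax weight becomes exactly 0
            scores = [-math.inf if mask[i][j] else s for j, s in enumerate(scores)]
        # numerically stable softmax over the keys
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # output for token i: attention-weighted average of the value vectors
        out.append([sum(w * vj[d] for w, vj in zip(weights, v)) for d in range(len(v[0]))])
    return out


# With all-zero queries every score ties, so attention is a plain average of v.
q = [[0.0, 0.0]] * 4
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]]
v = [[1.0], [2.0], [3.0], [100.0]]
mask = [[j == 3 for j in range(4)] for _ in range(4)]  # hide token 3 from every query
print(masked_attention(q, k, v)[0])        # average over all 4 values
print(masked_attention(q, k, v, mask)[0])  # average over the first 3 values only
```

Masking a key removes it from every affected query's softmax, so the large value of token 3 no longer leaks into the outputs.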

Axial attention

import torch
from self_attention_cv import AxialAttentionBlock
model = AxialAttentionBlock(in_channels=256, dim=64, heads=8)
x = torch.rand(1, 256, 64, 64)  # [batch, channels, height, width]
y = model(x)
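Axial attention, as in Axial-DeepLab, factorizes 2D self-attention into two 1D attentions: one along the height axis and one along the width axis. The back-of-the-envelope count below, which is not code from the library and ignores heads, channels and positional terms, shows why that matters for the 64x64 feature map used above.

```python
def attention_pair_counts(height, width):
    """Number of query-key scores per head for full 2D vs. axial attention.

    Full attention: each of the height*width pixels attends to all height*width pixels.
    Axial attention: one pass along the height axis (each pixel attends to the `height`
    pixels in its column) plus one pass along the width axis (the `width` pixels in its row).
    """
    tokens = height * width
    full = tokens * tokens
    axial = tokens * height + tokens * width
    return full, axial


full, axial = attention_pair_counts(64, 64)
print(f"full: {full:,}  axial: {axial:,}  ratio: {full // axial}x")
```

For a 64x64 map that is 16,777,216 scores for full attention versus 524,288 for the two axial passes, a 32x reduction; the gap grows with resolution because full attention scales with (H*W)^2 while axial attention scales with H*W*(H+W).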

Vanilla Transformer Encoder

import torch
from self_attention_cv import TransformerEncoder
model = TransformerEncoder(dim=64, blocks=6, heads=8)
x = torch.rand(16, 10, 64)  # [batch, tokens, dim]
mask = torch.zeros(10, 10, dtype=torch.bool)  # [tokens, tokens]
mask[5:8, 5:8] = True
y = model(x, mask)
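Inside each encoder block the channel dimension is shared out across the attention heads. Assuming the per-head size defaults to dim // heads (the usual convention), dim=64 with heads=8 gives 8 channels per head. The helpers below are an illustrative sketch of that split and its inverse on nested lists, not the library's code. The contiguous-slice layout is only one choice: because a learned projection precedes the split, any fixed channel-to-head assignment works, and the library's einops pattern may order channels differently.

```python
def split_heads(tokens, heads):
    """Split [num_tokens][dim] into [heads][num_tokens][dim // heads].

    Head h gets the contiguous channel slice h*dim_head:(h+1)*dim_head of every token.
    """
    dim = len(tokens[0])
    if dim % heads != 0:
        raise ValueError(f"dim={dim} is not divisible by heads={heads}")
    dim_head = dim // heads
    return [[t[h * dim_head:(h + 1) * dim_head] for t in tokens] for h in range(heads)]


def merge_heads(per_head):
    """Inverse of split_heads: concatenate the heads' channels back, token by token."""
    num_tokens = len(per_head[0])
    return [[value for head in per_head for value in head[i]] for i in range(num_tokens)]


# Same sizes as the encoder example: 10 tokens, dim=64, heads=8 -> dim_head=8
x = [[float(t * 64 + c) for c in range(64)] for t in range(10)]
per_head = split_heads(x, heads=8)
print(len(per_head), len(per_head[0]), len(per_head[0][0]))  # heads, tokens, dim_head
```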

Vision Transformer with/without ResNet50 backbone for image classification

import torch
from self_attention_cv import ViT, ResNet50ViT

model1 = ResNet50ViT(img_dim=128, pretrained_resnet=False,
                     blocks=6, num_classes=10,
                     dim_linear_block=256, dim=256)
# or
model2 = ViT(img_dim=256, in_channels=3, patch_dim=16, num_classes=10, dim=512)
x = torch.rand(2, 3, 256, 256)
y = model2(x)  # [2, 10]
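With img_dim=256 and patch_dim=16, ViT cuts the image into (256 // 16)^2 = 256 non-overlapping patches, each flattened to 16 * 16 * 3 = 768 values before a linear projection to dim=512. The sketch below shows only that tokenization step on nested lists. The channel ordering inside each flattened patch is an assumption (any fixed ordering works once a learned projection follows), and details such as a prepended classification token are left out.

```python
def patchify(image, patch_dim):
    """Cut a [channels][height][width] image into flattened, non-overlapping patches.

    Returns [num_patches][patch_dim * patch_dim * channels], patches in row-major order.
    """
    channels, height, width = len(image), len(image[0]), len(image[0][0])
    if height % patch_dim or width % patch_dim:
        raise ValueError("image height and width must be divisible by patch_dim")
    patches = []
    for top in range(0, height, patch_dim):
        for left in range(0, width, patch_dim):
            # flatten pixel by pixel; for each pixel, list all of its channels
            patch = [image[c][top + y][left + x]
                     for y in range(patch_dim)
                     for x in range(patch_dim)
                     for c in range(channels)]
            patches.append(patch)
    return patches


# Toy 3x4x4 image whose pixel value encodes (channel, row, col) as c*100 + y*10 + x
img = [[[c * 100 + y * 10 + x for x in range(4)] for y in range(4)] for c in range(3)]
patches = patchify(img, patch_dim=2)
print(len(patches), len(patches[0]))  # 4 patches of 2*2*3 = 12 values
```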

A re-implementation of Unet with the Vision Transformer encoder

import torch
from self_attention_cv.transunet import TransUnet
a = torch.rand(2, 3, 128, 128)
model = TransUnet(in_channels=3, img_dim=128, vit_blocks=8,
                  vit_dim_linear_mhsa_block=512, classes=5)
y = model(a)  # [2, 5, 128, 128]
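The TransUnet output has shape [batch, classes, height, width], i.e. one score map per class. Assuming these are per-class logits, as is usual for segmentation heads, a label map comes from taking the argmax over the class axis at every pixel. The pure-Python sketch below does this for a single image; with PyTorch tensors the equivalent is y.argmax(dim=1).

```python
def logits_to_labels(logits):
    """Per-pixel argmax over the class axis.

    logits: [classes][height][width] scores for one image -> [height][width] class ids.
    Ties resolve to the lowest class index.
    """
    classes, height, width = len(logits), len(logits[0]), len(logits[0][0])
    return [[max(range(classes), key=lambda c: logits[c][y][x])
             for x in range(width)]
            for y in range(height)]


logits = [[[0.1, 0.9]],   # class 0 scores for a 1x2 image
          [[0.5, 0.2]],   # class 1
          [[0.3, 0.0]]]   # class 2
print(logits_to_labels(logits))  # [[1, 0]]
```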

Bottleneck Attention block

import torch
from self_attention_cv.bottleneck_transformer import BottleneckBlock
inp = torch.rand(1, 512, 32, 32)
bottleneck_block = BottleneckBlock(in_channels=512, fmap_size=(32, 32), heads=4, out_channels=1024, pooling=True)
y = bottleneck_block(inp)
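A bottleneck attention block, in the spirit of the Bottleneck Transformer paper, swaps the spatial 3x3 convolution of a ResNet bottleneck for multi-head self-attention over the feature map, so with fmap_size=(32, 32) the attention operates on 32 * 32 = 1024 tokens (fmap_size is presumably also used to size the block's 2D relative position embeddings). The sketch below shows only the reshape between a [channels, height, width] map and a token sequence; it is an illustration, not the block's implementation.

```python
def feature_map_to_tokens(fmap):
    """Flatten a [channels][height][width] map into [height * width][channels] tokens.

    Tokens are ordered row by row, so token index = y * width + x.
    """
    channels, height, width = len(fmap), len(fmap[0]), len(fmap[0][0])
    return [[fmap[c][y][x] for c in range(channels)]
            for y in range(height)
            for x in range(width)]


def tokens_to_feature_map(tokens, height, width):
    """Inverse reshape: [height * width][channels] tokens back to [channels][height][width]."""
    channels = len(tokens[0])
    return [[[tokens[y * width + x][c] for x in range(width)]
             for y in range(height)]
            for c in range(channels)]


# 2 channels, 2x3 spatial map; value c*100 + y*10 + x encodes the position
fmap = [[[c * 100 + y * 10 + x for x in range(3)] for y in range(2)] for c in range(2)]
tokens = feature_map_to_tokens(fmap)
print(len(tokens), tokens[4])  # 6 [11, 111]  (token 4 is pixel y=1, x=1)
```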

Position embeddings are also available

1D Positional Embeddings

import torch
from self_attention_cv.pos_embeddings import AbsPosEmb1D, RelPosEmb1D

model = AbsPosEmb1D(tokens=20, dim_head=64)
# [batch, heads, tokens, dim_head]
q = torch.rand(2, 3, 20, 64)
y1 = model(q)

model = RelPosEmb1D(tokens=20, dim_head=64, heads=3)
q = torch.rand(2, 3, 20, 64)
y2 = model(q)
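Both 1D classes take queries shaped [batch, heads, tokens, dim_head]. One common way to use such embeddings, assumed here (check the source for exactly what the modules return), is to turn them into position-dependent attention logits: the absolute variant scores query i against a learned embedding for key position j, while the relative variant uses an embedding indexed by the offset j - i, so 2 * tokens - 1 embeddings cover every possible offset. The sketch below computes both for a single head on nested lists.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def abs_pos_logits(q, abs_emb):
    """Absolute 1D position logits: logits[i][j] = q_i . e_j.

    q: [tokens][dim_head]; abs_emb: [tokens][dim_head], one embedding per key position.
    """
    return [[dot(qi, ej) for ej in abs_emb] for qi in q]


def rel_pos_logits(q, rel_emb):
    """Relative 1D position logits: logits[i][j] = q_i . r_(j - i).

    rel_emb has 2 * tokens - 1 rows; row (j - i) + (tokens - 1) holds the embedding
    for offset j - i, which ranges from -(tokens - 1) to tokens - 1.
    """
    tokens = len(q)
    if len(rel_emb) != 2 * tokens - 1:
        raise ValueError("rel_emb needs 2 * tokens - 1 rows")
    return [[dot(q[i], rel_emb[j - i + tokens - 1]) for j in range(tokens)]
            for i in range(tokens)]


q = [[1.0], [2.0], [3.0]]                        # 3 tokens, dim_head = 1
rel_emb = [[-2.0], [-1.0], [0.0], [1.0], [2.0]]  # embedding for offset d is just [d]
print(rel_pos_logits(q, rel_emb))  # each entry equals q_i * (j - i)
```

Because the relative logits depend only on j - i, a shifted copy of the sequence produces the same pattern, which is the usual motivation for relative over absolute embeddings.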

2D Positional Embeddings

import torch
from self_attention_cv.pos_embeddings import RelPosEmb2D
dim = 32  # spatial dim of the feat map
model = RelPosEmb2D(
    feat_map_size=(dim, dim),
    dim_head=128)

q = torch.rand(2, 4, dim*dim, 128)
y = model(q)
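On a feature map, the offset between two tokens has a vertical and a horizontal component. A common factorization, used for example in the Bottleneck Transformer paper, keeps one set of relative embeddings per axis and adds their contributions; the sketch below assumes that factorization and illustrates it for a single head, not as the module's code. With feat_map_size=(32, 32) there are 32 * 32 = 1024 tokens, which is why q above has dim*dim entries along its token axis.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def rel_pos_logits_2d(q, rel_h, rel_w, height, width):
    """2D relative position logits for one head.

    q: [height * width][dim_head] queries in row-major token order.
    rel_h: 2 * height - 1 embeddings, one per vertical offset -(height-1) .. height-1.
    rel_w: 2 * width - 1 embeddings, one per horizontal offset.
    logits[i][j] = q_i . rel_h[dy] + q_i . rel_w[dx], (dy, dx) = offset from token i to j.
    """
    tokens = height * width
    logits = []
    for i in range(tokens):
        yi, xi = divmod(i, width)
        row = []
        for j in range(tokens):
            yj, xj = divmod(j, width)
            emb_h = rel_h[(yj - yi) + height - 1]
            emb_w = rel_w[(xj - xi) + width - 1]
            # the two axes contribute independently and are summed
            row.append(dot(q[i], emb_h) + dot(q[i], emb_w))
        logits.append(row)
    return logits


q = [[1.0]] * 4                    # 2x2 map, dim_head = 1
rel_h = [[-10.0], [0.0], [10.0]]   # vertical offsets -1, 0, +1
rel_w = [[-1.0], [0.0], [1.0]]     # horizontal offsets -1, 0, +1
logits = rel_pos_logits_2d(q, rel_h, rel_w, height=2, width=2)
print(logits[0])  # token (0, 0) vs tokens (0, 0), (0, 1), (1, 0), (1, 1)
```

The factorization needs only (2H - 1) + (2W - 1) embeddings instead of one per 2D offset, (2H - 1)(2W - 1).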

Comments
  • Thank you very much for the code. But when I run test_TransUnet.py, it starts reporting errors. Why is that? Could you please help me solve it? Thank you

    Traceback (most recent call last):
      File "self-attention-cv/tests/test_TransUnet.py", line 14, in <module>
        test_TransUnet()
      File "/self-attention-cv/tests/test_TransUnet.py", line 11, in test_TransUnet
        y = model(a)
      File "C:\Users\dell.conda\envs\myenv\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
        result = self.forward(*input, **kwargs)
      File "self-attention-cv\self_attention_cv\transunet\trans_unet.py", line 88, in forward
        y = self.project_patches_back(y)
      File "C:\Users\dell.conda\envs\myenv\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
        result = self.forward(*input, **kwargs)
      File "C:\Users\dell.conda\envs\myenv\lib\site-packages\torch\nn\modules\linear.py", line 93, in forward
        return F.linear(input, self.weight, self.bias)
      File "C:\Users\dell.conda\envs\myenv\lib\site-packages\torch\nn\functional.py", line 1692, in linear
        output = input.matmul(weight.t())
    RuntimeError: mat1 dim 1 must match mat2 dim 0

    Process finished with exit code 1

    Could you please help me solve it? Thank you.

    opened by yezhengjie 7
  • TransUNet - Why is the patch_dim set to 1?

    Hi,

    Can you please explain why the patch_dim is set to 1 in the TransUNet class? Thank you in advance!

    https://github.com/The-AI-Summer/self-attention-cv/blob/8280009366b633921342db6cab08da17b46fdf1c/self_attention_cv/transunet/trans_unet.py#L54

    opened by dsitnik 7
  • Question: Sliding Window Module for Transformer3dSeg Object

    I was wondering whether you've implemented an example using the network for a 3D medical segmentation task or use case. If this network only outputs the center slice of a patch, then we would need a wrapper function that iterates through all patches in an image to get the final prediction for the entire volume. From the original paper, I assume they choose 10 patches at random from an image during training, but it's not clear how they pieced everything together during testing.

    Your thoughts on this would be greatly appreciated!

    See: https://github.com/The-AI-Summer/self-attention-cv/blob/33ddf020d2d9fb9c4a4a3b9938383dc9b7405d8c/self_attention_cv/Transformer3Dsegmentation/tranf3Dseg.py#L10

    opened by jmarsil 5
  • ResNet + Pyramid Vision Transformer Version 2

    Thank you for your work and the clear explanations. As you know, ViT doesn't work well on small datasets, so I am implementing ResNet34 with Pyramid Vision Transformer Version 2 to improve results. The architectures of ViT and PVT V2 are completely different. Could you please help me implement it?

    opened by khawar-islam 3
  • Request for Including UNETR

    Thanks for the great work! I noticed a nice implementation of this paper (https://arxiv.org/abs/2103.10504) here:

    https://github.com/tamasino52/UNETR/blob/main/unetr.py

    It would be great if this could also be included in your repo, since it comes with lots of other great features, so we can explore more.

    Thanks ~

    opened by Siyuan89 3
  • ImageNet Pretrained TimesFormer

    I see you have recently added the TimesFormer model to this repository. In the paper, they initialize their model weights from ImageNet pretrained weights of ViT. Does your implementation offer this too? Thanks!

    opened by RaivoKoot 3
  • Do the encoder modules incorporate positional encoding?

    I am wondering whether, if I use, say, the LinformerEncoder, I have to add the positional encoding myself or whether that's already done. From the source files it doesn't seem to be there, but I'm not sure how to include the positional encoding, as the position embedding modules seem to need the query, which isn't available when passing data directly to the LinformerEncoder. I may very well be missing something; any help would be great. Perhaps an example using positional encoding would be good.

    opened by jfkback 3
  • use AxialAttention on gpu

    I tried to use AxialAttention on the GPU, but I get an error. Can you give me some tips about using AxialAttention on the GPU? Thanks! Error: RuntimeError: expected self and mask to be on the same device, but got mask on cpu and self on cuda:0

    opened by Iverson-Al 2
  • Axial attention

    What is the meaning of qkv_channels? https://github.com/The-AI-Summer/self-attention-cv/blob/5246e550ecb674f60df76a6c1011fde30ded7f44/self_attention_cv/axial_attention_deeplab/axial_attention.py#L32

    opened by Jayden9912 1
  • Convolution-Free Medical Image Segmentation using Transformers

    Thank you very much for your contribution. As a novice, I have a question. In tranf3dseg, the output of the model is the predicted segmentation of the center patch, so how can I get the segmentation of the whole input image? I am looking forward to any reply.

    opened by WinsaW 1
  • Regression with attention

    Hello!

    thanks for sharing this nice repo :)

    I'm trying to use ViT to do regression on images. I'd like to predict 6 floats per image.

    My understanding is that I'd need to simply define the network as

    vit = ViT(img_dim=128,
              in_channels=3,
              patch_dim=16,
              num_classes=6,
              dim=512)
    

    and during training call

    vit(x)
    

    and compute the loss as MSE instead of CE.

    The network actually runs but it doesn't seem to converge. Is there something obvious I am missing?

    many thanks!

    opened by alemelis 1
  • Segmentation for full image

    Hi,

    Thank you for your effort and time in implementing this. I have a quick question: I want to get the segmentation for the full image, not just for the middle token. Would it be correct to change self.tokens to self.p here:

    https://github.com/The-AI-Summer/self-attention-cv/blob/5246e550ecb674f60df76a6c1011fde30ded7f44/self_attention_cv/Transformer3Dsegmentation/tranf3Dseg.py#L66

    and change this:

    https://github.com/The-AI-Summer/self-attention-cv/blob/5246e550ecb674f60df76a6c1011fde30ded7f44/self_attention_cv/Transformer3Dsegmentation/tranf3Dseg.py#L94

    to

    y = self.mlp_seg_head(y)

    opened by aqibsaeed 0
Releases (1.2.3)
Owner
AI Summer
Learn Deep Learning and Artificial Intelligence