Implementation of self-attention mechanisms for general purpose. Focused on computer vision modules. Ongoing repository.

Overview

Self-attention building blocks for computer vision applications in PyTorch

Implementation of self attention mechanisms for computer vision in PyTorch with einsum and einops. Focused on computer vision self-attention modules.

Install it via pip

It would be nice to install pytorch in your enviroment, in case you don't have a GPU.

pip install self-attention-cv

Related articles

More articles are on the way.

Code Examples

Multi-head attention

import torch
from self_attention_cv import MultiHeadSelfAttention

model = MultiHeadSelfAttention(dim=64)
x = torch.rand(16, 10, 64)  # [batch, tokens, dim]
mask = torch.zeros(10, 10)  # tokens X tokens
mask[5:8, 5:8] = 1
y = model(x, mask)

Axial attention

import torch
from self_attention_cv import AxialAttentionBlock
model = AxialAttentionBlock(in_channels=256, dim=64, heads=8)
x = torch.rand(1, 256, 64, 64)  # [batch, tokens, dim, dim]
y = model(x)

Vanilla Transformer Encoder

import torch
from self_attention_cv import TransformerEncoder
model = TransformerEncoder(dim=64,blocks=6,heads=8)
x = torch.rand(16, 10, 64)  # [batch, tokens, dim]
mask = torch.zeros(10, 10)  # tokens X tokens
mask[5:8, 5:8] = 1
y = model(x,mask)

Vision Transformer with/without ResNet50 backbone for image classification

import torch
from self_attention_cv import ViT, ResNet50ViT

model1 = ResNet50ViT(img_dim=128, pretrained_resnet=False, 
                        blocks=6, num_classes=10, 
                        dim_linear_block=256, dim=256)
# or
model2 = ViT(img_dim=256, in_channels=3, patch_dim=16, num_classes=10,dim=512)
x = torch.rand(2, 3, 256, 256)
y = model2(x) # [2,10]

A re-implementation of Unet with the Vision Transformer encoder

import torch
from self_attention_cv.transunet import TransUnet
a = torch.rand(2, 3, 128, 128)
model = TransUnet(in_channels=3, img_dim=128, vit_blocks=8,
vit_dim_linear_mhsa_block=512, classes=5)
y = model(a) # [2, 5, 128, 128]

Bottleneck Attention block

import torch
from self_attention_cv.bottleneck_transformer import BottleneckBlock
inp = torch.rand(1, 512, 32, 32)
bottleneck_block = BottleneckBlock(in_channels=512, fmap_size=(32, 32), heads=4, out_channels=1024, pooling=True)
y = bottleneck_block(inp)

Position embeddings are also available

1D Positional Embeddings

import torch
from self_attention_cv.pos_embeddings import AbsPosEmb1D,RelPosEmb1D

model = AbsPosEmb1D(tokens=20, dim_head=64)
# batch heads tokens dim_head
q = torch.rand(2, 3, 20, 64)
y1 = model(q)

model = RelPosEmb1D(tokens=20, dim_head=64, heads=3)
q = torch.rand(2, 3, 20, 64)
y2 = model(q)

2D Positional Embeddings

import torch
from self_attention_cv.pos_embeddings import RelPosEmb2D
dim = 32  # spatial dim of the feat map
model = RelPosEmb2D(
    feat_map_size=(dim, dim),
    dim_head=128)

q = torch.rand(2, 4, dim*dim, 128)
y = model(q)

References

  1. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., ... & Polosukhin, I. (2017). Attention is all you need. arXiv preprint arXiv:1706.03762.
  2. Wang, H., Zhu, Y., Green, B., Adam, H., Yuille, A., & Chen, L. C. (2020, August). Axial-deeplab: Stand-alone axial-attention for panoptic segmentation. In European Conference on Computer Vision (pp. 108-126). Springer, Cham.
  3. Srinivas, A., Lin, T. Y., Parmar, N., Shlens, J., Abbeel, P., & Vaswani, A. (2021). Bottleneck Transformers for Visual Recognition. arXiv preprint arXiv:2101.11605.
  4. Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., ... & Houlsby, N. (2020). An image is worth 16x16 words: Transformers for image recognition at scale. arXiv preprint arXiv:2010.11929.
Comments
  • Thank you very much for the code. But when I run test_TransUnet.py , It starts reporting errors. Why is that? Could you please help me solve it? Thank you

    Thank you very much for the code. But when I run test_TransUnet.py , It starts reporting errors. Why is that? Could you please help me solve it? Thank you

    Thank you very much for the code. But when I run test_TransUnet.py , It starts reporting errors. Why is that?I `Traceback (most recent call last): File "self-attention-cv/tests/test_TransUnet.py", line 14, in test_TransUnet() File "/self-attention-cv/tests/test_TransUnet.py", line 11, in test_TransUnet y = model(a) File "C:\Users\dell.conda\envs\myenv\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl result = self.forward(*input, **kwargs) File "self-attention-cv\self_attention_cv\transunet\trans_unet.py", line 88, in forward y = self.project_patches_back(y) File "C:\Users\dell.conda\envs\myenv\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl result = self.forward(*input, **kwargs) File "C:\Users\dell.conda\envs\myenv\lib\site-packages\torch\nn\modules\linear.py", line 93, in forward return F.linear(input, self.weight, self.bias) File "C:\Users\dell.conda\envs\myenv\lib\site-packages\torch\nn\functional.py", line 1692, in linear output = input.matmul(weight.t()) RuntimeError: mat1 dim 1 must match mat2 dim 0

    Process finished with exit code 1 ` Could you please help me solve it? Thank you.

    opened by yezhengjie 7
  • TransUNet - Why is the patch_dim set to 1?

    TransUNet - Why is the patch_dim set to 1?

    Hi,

    Can you please explain why is the patch_dim set to 1 in TransUNet class? Thank you in advance!

    https://github.com/The-AI-Summer/self-attention-cv/blob/8280009366b633921342db6cab08da17b46fdf1c/self_attention_cv/transunet/trans_unet.py#L54

    opened by dsitnik 7
  • Question: Sliding Window Module for Transformer3dSeg Object

    Question: Sliding Window Module for Transformer3dSeg Object

    I was wondering whether or not you've implemented an example using the network in a 3d medical segmentation task and/or use case? If this network only exports the center slice of a patch then we would need a wrapper function to iterate through all patches in an image to get the final prediction for the entire volume. From the original paper, I assume they choose 10 patches at random from an image during training, but it's not too clear how they pieced everything together during testing.

    Your thoughts on this would be greatly appreciated!

    See: https://github.com/The-AI-Summer/self-attention-cv/blob/33ddf020d2d9fb9c4a4a3b9938383dc9b7405d8c/self_attention_cv/Transformer3Dsegmentation/tranf3Dseg.py#L10

    opened by jmarsil 5
  • ResNet + Pyramid Vision Transformer Version 2

    ResNet + Pyramid Vision Transformer Version 2

    Thank you for your work with a clear explanation. As you know, ViT doesn't work on small datasets and I am implementing ResNet34 with Pyramid Vision Transformer Version 2 to make it better. The architecture of ViT and PVT V2 is completely different. Could you provide me some help to implement it? please

    opened by khawar-islam 3
  • Request for Including UNETR

    Request for Including UNETR

    Thanks for great work ! I noticed nice implementation of this paper (https://arxiv.org/abs/2103.10504) here:

    https://github.com/tamasino52/UNETR/blob/main/unetr.py

    It would be great if this can also be included in your repo, since it comes with lots of other great features. So we can explore more.

    Thanks ~

    opened by Siyuan89 3
  • ImageNet Pretrained TimesFormer

    ImageNet Pretrained TimesFormer

    I see you have recently added the TimesFormer model to this repository. In the paper, they initialize their model weights from ImageNet pretrained weights of ViT. Does your implementation offer this too? Thanks!

    opened by RaivoKoot 3
  • Do the encoder modules incorporate positional encoding?

    Do the encoder modules incorporate positional encoding?

    I am wondering if I use say the LinformerEncoder if I have to add the position encoding or if that's already done? From the source files it doesn't seem to be there, but I'm not sure how to include the position encoding as they seem to need the query which isn't available when just passing data directly to the LinformerEncoder. I very well may be missing something any help would be great. Perhaps an example using positional encoding would be good.

    opened by jfkback 3
  • use AxialAttention on gpu

    use AxialAttention on gpu

    I try to use AxialAttention on gpu, but I get a mistake.Can you give me some tips about using AxialAttention on gpu. Thanks! mistake: RuntimeError: expected self and mask to be on the same device, but got mask on cpu and self on cuda:0

    opened by Iverson-Al 2
  • Axial attention

    Axial attention

    What is the meaning of qkv_channels? https://github.com/The-AI-Summer/self-attention-cv/blob/5246e550ecb674f60df76a6c1011fde30ded7f44/self_attention_cv/axial_attention_deeplab/axial_attention.py#L32

    opened by Jayden9912 1
  • Convolution-Free Medical Image Segmentation using Transformers

    Convolution-Free Medical Image Segmentation using Transformers

    Thank you very much for your contribution. As a novice, I have a doubt. In tranf3dseg, the output of the model is the prediction segmentation of the center patch, so how can I get the segmentation of the whole input image? I am looking forward to any reply.

    opened by WinsaW 1
  • Regression with attention

    Regression with attention

    Hello!

    thanks for sharing this nice repo :)

    I'm trying to use ViT to do regression on images. I'd like to predict 6 floats per image.

    My understanding is that I'd need to simply define the network as

    vit = ViT(img_dim=128,
                   in_channels=3,
                   patch_dim=16,
                   num_classes=6,
                   dim=512)
    

    and during training call

    vit(x)
    

    and compute the loss as MSE instead of CE.

    The network actually runs but it doesn't seem to converge. Is there something obvious I am missing?

    many thanks!

    opened by alemelis 1
  • Segmentation for full image

    Segmentation for full image

    Hi,

    Thank you for your effort and time in implementing this. I have a quick question, I want to get segmentation for full image not just for the middle token, would it be correct to change self.tokens to self.p here:

    https://github.com/The-AI-Summer/self-attention-cv/blob/5246e550ecb674f60df76a6c1011fde30ded7f44/self_attention_cv/Transformer3Dsegmentation/tranf3Dseg.py#L66

    and change this:

    https://github.com/The-AI-Summer/self-attention-cv/blob/5246e550ecb674f60df76a6c1011fde30ded7f44/self_attention_cv/Transformer3Dsegmentation/tranf3Dseg.py#L94

    to

    y = self.mlp_seg_head(y)

    opened by aqibsaeed 0
Releases(1.2.3)
Owner
AI Summer
Learn Deep Learning and Artificial Intelligence
AI Summer
Official Implementation for HyperStyle: StyleGAN Inversion with HyperNetworks for Real Image Editing

HyperStyle: StyleGAN Inversion with HyperNetworks for Real Image Editing Yuval Alaluf*, Omer Tov*, Ron Mokady, Rinon Gal, Amit H. Bermano *Denotes equ

885 Jan 06, 2023
Repositorio oficial del curso IIC2233 Programación Avanzada 🚀✨

IIC2233 - Programación Avanzada Evaluación Las evaluaciones serán efectuadas por medio de actividades prácticas en clases y tareas. Se calculará la no

IIC2233 @ UC 47 Sep 06, 2022
🦕 NanoSaur is a little tracked robot ROS2 enabled, made for an NVIDIA Jetson Nano

🦕 nanosaur NanoSaur is a little tracked robot ROS2 enabled, made for an NVIDIA Jetson Nano Website: nanosaur.ai Do you need an help? Discord For tech

NanoSaur 162 Dec 09, 2022
Neural style transfer as a class in PyTorch

pt-styletransfer Neural style transfer as a class in PyTorch Based on: https://github.com/alexis-jacq/Pytorch-Tutorials Adds: StyleTransferNet as a cl

Tyler Kvochick 31 Jun 27, 2022
CTF Challenge for CSAW Finals 2021

Terminal Velocity Misc CTF Challenge for CSAW Finals 2021 This is a challenge I've had in mind for almost 15 years and never got around to building un

Jordan 6 Jul 30, 2022
CLIP + VQGAN / PixelDraw

clipit Yet Another VQGAN-CLIP Codebase This started as a fork of @nerdyrodent's VQGAN-CLIP code which was based on the notebooks of @RiversWithWings a

dribnet 276 Dec 12, 2022
CARLA: A Python Library to Benchmark Algorithmic Recourse and Counterfactual Explanation Algorithms

CARLA - Counterfactual And Recourse Library CARLA is a python library to benchmark counterfactual explanation and recourse models. It comes out-of-the

Carla Recourse 200 Dec 28, 2022
基于Paddlepaddle复现yolov5,支持PaddleDetection接口

PaddleDetection yolov5 https://github.com/Sharpiless/PaddleDetection-Yolov5 简介 PaddleDetection飞桨目标检测开发套件,旨在帮助开发者更快更好地完成检测模型的组建、训练、优化及部署等全开发流程。 PaddleD

36 Jan 07, 2023
Official pytorch implementation of "DSPoint: Dual-scale Point Cloud Recognition with High-frequency Fusion"

DSPoint Official implementation of "DSPoint: Dual-scale Point Cloud Recognition with High-frequency Fusion". Paper link: https://arxiv.org/abs/2111.10

Ziyao Zeng 14 Feb 26, 2022
Transfer-Learn is an open-source and well-documented library for Transfer Learning.

Transfer-Learn is an open-source and well-documented library for Transfer Learning. It is based on pure PyTorch with high performance and friendly API. Our code is pythonic, and the design is consist

THUML @ Tsinghua University 2.2k Jan 03, 2023
Garbage Detection system which will detect objects based on whether it is plastic waste or plastics or just garbage.

Garbage Detection using Yolov5 on Jetson Nano 2gb Developer Kit. Garbage detection system which will detect objects based on whether it is plastic was

Rishikesh A. Bondade 2 May 13, 2022
WRENCH: Weak supeRvision bENCHmark

🔧 What is it? Wrench is a benchmark platform containing diverse weak supervision tasks. It also provides a common and easy framework for development

Jieyu Zhang 176 Dec 28, 2022
Implementation of BI-RADS-BERT & The Advantages of Section Tokenization.

BI-RADS BERT Implementation of BI-RADS-BERT & The Advantages of Section Tokenization. This implementation could be used on other radiology in house co

1 May 17, 2022
A list of all named GANs!

The GAN Zoo Every week, new GAN papers are coming out and it's hard to keep track of them all, not to mention the incredibly creative ways in which re

Avinash Hindupur 12.9k Jan 08, 2023
Alleviating Over-segmentation Errors by Detecting Action Boundaries

Alleviating Over-segmentation Errors by Detecting Action Boundaries Forked from ASRF offical code. This repo is the a implementation of replacing orig

13 Dec 12, 2022
One-Shot Neural Ensemble Architecture Search by Diversity-Guided Search Space Shrinking

One-Shot Neural Ensemble Architecture Search by Diversity-Guided Search Space Shrinking This is an official implementation for NEAS presented in CVPR

Multimedia Research 19 Sep 08, 2022
Pytorch implementation of Implicit Behavior Cloning.

Implicit Behavior Cloning - PyTorch (wip) Pytorch implementation of Implicit Behavior Cloning. Install conda create -n ibc python=3.8 pip install -r r

Kevin Zakka 49 Dec 25, 2022
Code for paper "ASAP-Net: Attention and Structure Aware Point Cloud Sequence Segmentation"

ASAP-Net This project implements ASAP-Net of paper ASAP-Net: Attention and Structure Aware Point Cloud Sequence Segmentation (BMVC2020). Overview We i

Hanwen Cao 26 Aug 25, 2022
ROS support for Velodyne 3D LIDARs

Overview Velodyne1 is a collection of ROS2 packages supporting Velodyne high definition 3D LIDARs3. Warning: The master branch normally contains code

ROS device drivers 543 Dec 30, 2022
A Light in the Dark: Deep Learning Practices for Industrial Computer Vision

A Light in the Dark: Deep Learning Practices for Industrial Computer Vision This is the repository for our Paper/Contribution to the WI2022 in Nürnber

Maximilian Harl 6 Jan 17, 2022