Implementation of Deformable Attention in Pytorch from the paper "Vision Transformer with Deformable Attention"

Overview

Deformable Attention

Implementation of deformable attention from the paper "Vision Transformer with Deformable Attention" in PyTorch, which appears to be an improvement on what was proposed in DETR. The relative positional embedding has also been modified for better extrapolation, using the continuous positional embedding proposed in SwinV2.
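
To make the idea concrete, here is a minimal pure-Python sketch of the sampling step that deformable attention relies on: a coarse grid of reference points is shifted by per-point offsets, and the feature map is read at the shifted positions with bilinear interpolation. It is an illustration only, not the code used by this package. The names `bilinear_sample` and `deformed_points` are hypothetical, and the offsets are written by hand, whereas in the paper they are predicted from the queries by a small offset network.

```python
import math


def bilinear_sample(feature, y, x):
    """Read a 2D feature map (list of rows) at a fractional position (y, x)."""
    height, width = len(feature), len(feature[0])
    # Clamp so that points pushed outside the map are read at its border.
    y = min(max(y, 0.0), height - 1.0)
    x = min(max(x, 0.0), width - 1.0)
    y0, x0 = math.floor(y), math.floor(x)
    y1, x1 = min(y0 + 1, height - 1), min(x0 + 1, width - 1)
    dy, dx = y - y0, x - x0
    top = feature[y0][x0] * (1 - dx) + feature[y0][x1] * dx
    bottom = feature[y1][x0] * (1 - dx) + feature[y1][x1] * dx
    return top * (1 - dy) + bottom * dy


def deformed_points(height, width, stride, offsets):
    """Place one reference point at the centre of each stride x stride cell,
    then shift every point by its (dy, dx) offset."""
    cells = [(gy, gx) for gy in range(height // stride) for gx in range(width // stride)]
    points = []
    for (gy, gx), (dy, dx) in zip(cells, offsets):
        ref_y = gy * stride + (stride - 1) / 2
        ref_x = gx * stride + (stride - 1) / 2
        points.append((ref_y + dy, ref_x + dx))
    return points


# Toy 4x4 feature map whose value encodes its position: feature[y][x] = 10 * y + x.
feature = [[10 * y + x for x in range(4)] for y in range(4)]

# Hand-written offsets, one per reference point (hypothetical values).
offsets = [(0.0, 0.0), (0.5, 0.0), (0.0, -0.5), (-1.0, -1.0)]

points = deformed_points(4, 4, stride=2, offsets=offsets)
sampled = [bilinear_sample(feature, y, x) for y, x in points]
print(points)   # [(0.5, 0.5), (1.0, 2.5), (2.5, 0.0), (1.5, 1.5)]
print(sampled)  # [5.5, 12.5, 25.0, 16.5]
```

In the paper, the features gathered this way are then projected into keys and values, so each query attends to a small, input-dependent set of positions rather than to every position in the map.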

Install

$ pip install deformable-attention

Usage

import torch
from deformable_attention import DeformableAttention

attn = DeformableAttention(
    dim = 512,                   # feature dimensions
    dim_head = 64,               # dimension per head
    heads = 8,                   # attention heads
    dropout = 0.,                # dropout
    downsample_factor = 4,       # downsample factor (r in paper)
    offset_scale = 4,            # scale of offset, maximum offset
    offset_groups = None,        # number of offset groups, should be multiple of heads
    offset_kernel_size = 6,      # offset kernel size
)

x = torch.randn(1, 512, 64, 64)
attn(x) # (1, 512, 64, 64)
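
As a rough guide to what `downsample_factor` and `offset_kernel_size` do to the shapes, the sketch below applies the standard convolution output-size formula to the offset grid. It assumes the offsets come from a strided convolution with stride `downsample_factor` and padding `(offset_kernel_size - downsample_factor) / 2`. That padding rule is an assumption about the implementation, not something stated in this README, and `offset_grid_size` is a hypothetical helper.

```python
def offset_grid_size(size, kernel_size, stride):
    """Output length of a strided convolution along one axis, using the
    standard formula floor((n + 2p - k) / s) + 1.

    Assumption: padding p = (kernel_size - stride) / 2, so the difference
    between kernel size and stride must be even and non-negative.
    """
    if kernel_size < stride or (kernel_size - stride) % 2 != 0:
        raise ValueError("kernel_size - stride must be even and non-negative")
    padding = (kernel_size - stride) // 2
    return (size + 2 * padding - kernel_size) // stride + 1


# Values from the 2D example above: 64x64 feature map,
# downsample_factor = 4, offset_kernel_size = 6.
grid_h = offset_grid_size(64, kernel_size=6, stride=4)
grid_w = offset_grid_size(64, kernel_size=6, stride=4)
print(grid_h, grid_w)   # 16 16
print(grid_h * grid_w)  # 256 sampled positions, versus 64 * 64 = 4096 queries
```

Under that padding assumption the grid size reduces to the input size divided by the downsample factor whenever the division is exact, which matches the role of r in the paper.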

3D deformable attention

import torch
from deformable_attention import DeformableAttention3D

attn = DeformableAttention3D(
    dim = 512,                          # feature dimensions
    dim_head = 64,                      # dimension per head
    heads = 8,                          # attention heads
    dropout = 0.,                       # dropout
    downsample_factor = (2, 8, 8),      # downsample factor (r in paper)
    offset_scale = (2, 8, 8),           # scale of offset, maximum offset
    offset_kernel_size = (4, 10, 10),   # offset kernel size
)

x = torch.randn(1, 512, 10, 32, 32) # (batch, dimension, frames, height, width)
attn(x) # (1, 512, 10, 32, 32)
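
To get a feel for why downsampling the sampling grid matters for video-sized inputs, the following back-of-the-envelope sketch counts query-key pairs for the example above. It assumes each axis of the sampling grid is the input size divided by the matching `downsample_factor` entry (the role of r in the paper). `attention_pairs` is a hypothetical helper, and the counts ignore heads and batch size.

```python
from math import prod


def attention_pairs(shape, downsample_factor):
    """Return (full, deformable) query-key pair counts for one head.

    Assumption: the sampling grid has shape[i] // downsample_factor[i]
    points along axis i, and every input position acts as a query.
    """
    queries = prod(shape)
    sampled = prod(n // r for n, r in zip(shape, downsample_factor))
    return queries * queries, queries * sampled


full, deformable = attention_pairs((10, 32, 32), (2, 8, 8))
print(full)                # 104857600
print(deformable)          # 819200
print(full // deformable)  # 128
```

The ratio equals the product of the downsample factors (2 * 8 * 8 = 128), so larger factors trade spatial coverage of the sampled keys for a proportionally smaller attention matrix.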

1D deformable attention, for good measure

import torch
from deformable_attention import DeformableAttention1D

attn = DeformableAttention1D(
    dim = 128,
    downsample_factor = 4,
    offset_scale = 2,
    offset_kernel_size = 6
)

x = torch.randn(1, 128, 512)
attn(x) # (1, 128, 512)
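
The `offset_scale` argument is described above as the maximum offset. One way to enforce such a bound, and the one given in the paper, is to squash the raw offset with tanh and multiply by the scale. The sketch below illustrates that rule in plain Python. `bounded_offset` is a hypothetical name, and whether this package applies the bound in exactly this form is an assumption.

```python
import math


def bounded_offset(raw_offset, offset_scale):
    """Squash an unbounded raw offset into [-offset_scale, offset_scale]."""
    return offset_scale * math.tanh(raw_offset)


# offset_scale = 2, as in the 1D example above.
for raw in (-100.0, -1.0, 0.0, 1.0, 100.0):
    print(raw, round(bounded_offset(raw, 2), 4))
```

Small raw offsets pass through almost linearly, while large ones saturate at the scale, so no sampling point can move further than `offset_scale` from its reference position.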

Citation

@misc{xia2022vision,
    title   = {Vision Transformer with Deformable Attention}, 
    author  = {Zhuofan Xia and Xuran Pan and Shiji Song and Li Erran Li and Gao Huang},
    year    = {2022},
    eprint  = {2201.00520},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}

@misc{liu2021swin,
    title   = {Swin Transformer V2: Scaling Up Capacity and Resolution},
    author  = {Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
    year    = {2021},
    eprint  = {2111.09883},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}

Comments
  • The relationship between 'dim' and 'inner_dim'

    Hi, I have a question about the DeformableAttention module.

    I calculated the output shapes of the forward pass step by step. According to my calculation, the code works only when 'dim' and 'inner_dim' are the same.

    Is there any reason why you implemented it this way?

    Best regards, Hankyu

    opened by hanq0212
  • TypeError: meshgrid() got an unexpected keyword argument 'indexing'

    @lucidrains I get the error above when running the following:

    import torch
    from deformable_attention import DeformableAttention

    attn = DeformableAttention(
        dim = 512,                   # feature dimensions
        dim_head = 64,               # dimension per head
        heads = 8,                   # attention heads
        dropout = 0.,                # dropout
        downsample_factor = 4,       # downsample factor (r in paper)
        offset_scale = 4,            # scale of offset, maximum offset
        offset_groups = None,        # number of offset groups, should be multiple of heads
        offset_kernel_size = 6,      # offset kernel size
    )

    x = torch.randn(1, 512, 64, 64)
    attn(x)

    The error is raised from the line below. Kindly help.

    https://github.com/lucidrains/deformable-attention/blob/9f3c0ae35652ce54687e0db409921018bfca3310/deformable_attention/deformable_attention_2d.py#L26

    opened by ChidanandKumarKS
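
A likely cause of the meshgrid error, offered as an assumption rather than a confirmed diagnosis: the `indexing` keyword of `torch.meshgrid` was added in PyTorch 1.10, so older installations raise this TypeError. The small helper below (hypothetical name `supports_meshgrid_indexing`) checks a version string without importing torch. In practice it would be passed `torch.__version__`.

```python
import re


def supports_meshgrid_indexing(version):
    """Return True if a PyTorch version string is 1.10 or newer."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    return (major, minor) >= (1, 10)


print(supports_meshgrid_indexing("1.9.1"))         # False
print(supports_meshgrid_indexing("1.10.0+cu113"))  # True
print(supports_meshgrid_indexing("2.1.0"))         # True
```

If the check returns False, upgrading PyTorch is probably the simplest way to run the usage example unchanged.
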
Owner
Phil Wang
Working with Attention. It's all we need