Implementation of Transformer in Transformer, pixel-level attention paired with patch-level attention for image classification, in PyTorch

Overview

Transformer in Transformer

Install

$ pip install transformer-in-transformer

Usage

import torch
from transformer_in_transformer import TNT

tnt = TNT(
    image_size = 256,       # size of image
    patch_dim = 512,        # dimension of patch token
    pixel_dim = 24,         # dimension of pixel token
    patch_size = 16,        # patch size
    pixel_size = 4,         # pixel size
    depth = 6,              # depth
    num_classes = 1000,     # output number of classes
    attn_dropout = 0.1,     # attention dropout
    ff_dropout = 0.1        # feedforward dropout
)

img = torch.randn(2, 3, 256, 256)
logits = tnt(img) # (2, 1000)

Citations

@misc{han2021transformer,
    title   = {Transformer in Transformer}, 
    author  = {Kai Han and An Xiao and Enhua Wu and Jianyuan Guo and Chunjing Xu and Yunhe Wang},
    year    = {2021},
    eprint  = {2103.00112},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
Comments
  • Only works if pixel_size**2 == patch_size?

    Hi, is this only supposed to work if pixel_size**2 == patch_size? When patch_size is set to any value that doesn't satisfy this equation, the following error occurs:

    --> 146         pixels += rearrange(self.pixel_pos_emb, 'n d -> () n d')
        147 
        148         for pixel_attn, pixel_ff, pixel_to_patch_residual, patch_attn, patch_ff in self.layers:
    
    RuntimeError: The size of tensor a (4) must match the size of tensor b (64) at non-singleton dimension 1
    

    The error occurred when running:

    tnt = TNT(
        image_size = 128,       # size of image
        patch_dim = 256,        # dimension of patch token
        pixel_dim = 24,         # dimension of pixel token
        patch_size = 16,        # patch size
        pixel_size = 2,         # pixel size
        depth = 6,              # depth
        heads = 1,
        num_classes = 2,     # output number of classes
        attn_dropout = 0.1,     # attention dropout
        ff_dropout = 0.1        # feedforward dropout,
    )
    img = torch.randn(2, 3, 128, 128)
    logits = tnt(img)
    

    Since I am completely new to einops, it's quite hard for me to debug :D Thanks

    opened by PhilippMarquardt 1
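The two sizes in the error message above (4 and 64) match a simple count, sketched here under two assumptions that the issue does not confirm: each patch is cut into non-overlapping pixel_size x pixel_size blocks, and the pixel positional embedding was allocated with pixel_size ** 2 rows. Both helper names are hypothetical and are not part of the package.

```python
def pixel_tokens_per_patch(patch_size: int, pixel_size: int) -> int:
    """Number of non-overlapping pixel_size x pixel_size blocks in one patch."""
    return (patch_size // pixel_size) ** 2


def assumed_pos_emb_rows(pixel_size: int) -> int:
    """Assumed row count of the pixel positional embedding (an assumption)."""
    return pixel_size ** 2


# (16, 4) is the README configuration, (16, 2) is the failing one from the issue.
for patch_size, pixel_size in [(16, 4), (16, 2)]:
    tokens = pixel_tokens_per_patch(patch_size, pixel_size)
    rows = assumed_pos_emb_rows(pixel_size)
    status = "match" if tokens == rows else "mismatch"
    print(f"patch_size={patch_size} pixel_size={pixel_size}: "
          f"{tokens} tokens vs {rows} embedding rows ({status})")
```

Under these assumptions the two counts agree only when patch_size == pixel_size ** 2, which is the condition the question asks about.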
  • Not sure what is wrong!


    RuntimeError                              Traceback (most recent call last)
    in
         14
         15 img = torch.randn(1, 3, 256, 256)
    ---> 16 logits = tnt(img) # (2, 1000)

    ~/opt/anaconda3/envs/ml/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
       1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
       1109                 or _global_forward_hooks or _global_forward_pre_hooks):
    -> 1110             return forward_call(*input, **kwargs)
       1111         # Do not call functions when jit is used
       1112         full_backward_hooks, non_full_backward_hooks = [], []

    ~/opt/anaconda3/envs/ml/lib/python3.8/site-packages/transformer_in_transformer/tnt.py in forward(self, x)
        159         patches = repeat(self.patch_tokens[:(n + 1)], 'n d -> b n d', b = b)
        160
    --> 161         patches += rearrange(self.patch_pos_emb[:(n + 1)], 'n d -> () n d')
        162         pixels += rearrange(self.pixel_pos_emb, 'n d -> () n d')
        163

    RuntimeError: a view of a leaf Variable that requires grad is being used in an in-place operation.

    opened by RisabBiswas 0
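The same RuntimeError can be reproduced in isolation, which may help narrow the problem down. The sketch below uses plain PyTorch indexing instead of einops and made-up tensor sizes, so it is an illustration of the error class rather than the package's code. The traceback suggests that patches was still a view of the patch_tokens parameter when the in-place += ran; adding out of place avoids the error.

```python
import torch
from torch import nn

# Stand-ins for the two parameters in the traceback (sizes are arbitrary).
patch_tokens = nn.Parameter(torch.randn(5, 8))
patch_pos_emb = nn.Parameter(torch.randn(5, 8))
n = 3

# Slicing and unsqueezing a parameter yields a view of a leaf tensor.
patches_view = patch_tokens[: n + 1].unsqueeze(0)

raised = False
try:
    # In-place update on a view of a leaf that requires grad is rejected.
    patches_view += patch_pos_emb[: n + 1].unsqueeze(0)
except RuntimeError as err:
    raised = True
    print("in-place failed:", err)

# Out-of-place addition creates a new tensor and keeps autograd intact.
patches = patch_tokens[: n + 1].unsqueeze(0) + patch_pos_emb[: n + 1].unsqueeze(0)
print(patches.shape)
```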
  • patch_tokens vs patch_pos_emb

    Hi!

    I'm trying to understand your TNT implementation, and one thing that confused me is why there are two parameters, patch_tokens and patch_pos_emb, which seem to serve the same purpose: encoding patch position. Isn't one of them redundant?

    self.patch_tokens = nn.Parameter(torch.randn(num_patch_tokens + 1, patch_dim))
    self.patch_pos_emb = nn.Parameter(torch.randn(num_patch_tokens + 1, patch_dim))
    ...
    patches = repeat(self.patch_tokens[:(n + 1)], 'n d -> b n d', b = b)
    patches += rearrange(self.patch_pos_emb[:(n + 1)], 'n d -> () n d')
    
    opened by stas-sl 0
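One way to probe this question is to compare the gradients the two parameters receive. The sketch below only models the two quoted lines (a learned tensor broadcast over the batch, plus a second learned tensor of the same shape); it uses plain PyTorch broadcasting in place of einops, arbitrary sizes, and a toy loss, and it says nothing about how the rest of forward uses either parameter.

```python
import torch
from torch import nn

torch.manual_seed(0)
num_tokens, dim, batch = 5, 8, 2

# Stand-ins for the two parameters quoted in the issue.
patch_tokens = nn.Parameter(torch.randn(num_tokens, dim))
patch_pos_emb = nn.Parameter(torch.randn(num_tokens, dim))

# Equivalent of repeat(..., 'n d -> b n d') followed by adding '() n d'.
patches = patch_tokens.unsqueeze(0).expand(batch, -1, -1) + patch_pos_emb.unsqueeze(0)

# Toy loss, only there to produce gradients.
loss = patches.pow(2).sum()
loss.backward()

same_grad = torch.allclose(patch_tokens.grad, patch_pos_emb.grad)
print("identical gradients:", same_grad)
```

If the gradients are identical, as they are in this isolated sketch, the output depends only on the sum of the two tensors.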
  • Inconsistent model params with MindSpore src code

    There is no function or README description of the TNT-S/TNT-B models in this codebase, something like:

    def tnt_b(num_class):
        return TNT(img_size=384,
                   patch_size=16,
                   num_channels=3,
                   embedding_dim=640,
                   num_heads=10,
                   num_layers=12,
                   hidden_dim=640*4,
                   stride=4,
                   num_class=num_class)
    

    Also, the number of heads in the inner block should be 4: https://github.com/lucidrains/transformer-in-transformer/blob/main/transformer_in_transformer/tnt.py#L135

    Has anyone reproduced the results reported in the paper with this codebase?

    opened by WongChen 0
  • Why does the loss become NaN?

    This is a great project, and I am very interested in the Transformer in Transformer model. I used your model to train on the Vehicle-1M dataset, a fine-grained visual classification dataset. With this model, the loss becomes NaN after some batch iterations. I decreased the learning rate of the Adam optimizer and clipped the gradients with torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0, norm_type=2), but the loss still sometimes becomes NaN. It seems that the gradients are not large, but they point in the same direction for many iterations. How can I solve this?

    opened by yt7589 3
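A common way to locate where a NaN first appears is to guard each training step. The sketch below combines the clipping call quoted in the issue with a finiteness check on the loss and on the gradient norm. The nn.Linear model, the learning rate, and the train_step helper are stand-ins chosen for illustration, not the setup from the issue; a TNT instance could be substituted for the model.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Stand-in model and optimizer (assumptions, not the issue's configuration).
model = nn.Linear(16, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()


def train_step(inputs, targets):
    """Run one step; return the loss, or None if the step was skipped."""
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    if not torch.isfinite(loss):
        return None  # skip the batch instead of propagating NaN/Inf
    loss.backward()
    # Same clipping call as in the issue; it returns the total gradient norm.
    grad_norm = torch.nn.utils.clip_grad_norm_(
        model.parameters(), max_norm=2.0, norm_type=2
    )
    if not torch.isfinite(grad_norm):
        optimizer.zero_grad()
        return None
    optimizer.step()
    return loss.item()


targets = torch.randint(0, 4, (8,))
ok = train_step(torch.randn(8, 16), targets)
bad = train_step(torch.full((8, 16), float("nan")), targets)
print(ok, bad)
```

Logging which batches are skipped can indicate whether the NaN originates in the data or in the model.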
Owner
Phil Wang
Working with Attention. It's all we need.