Implementation of Transformer in Transformer, pixel-level attention paired with patch-level attention for image classification, in PyTorch

Overview

Transformer in Transformer

Install

$ pip install transformer-in-transformer

Usage

import torch
from transformer_in_transformer import TNT

tnt = TNT(
    image_size = 256,       # size of image
    patch_dim = 512,        # dimension of patch token
    pixel_dim = 24,         # dimension of pixel token
    patch_size = 16,        # patch size
    pixel_size = 4,         # pixel size
    depth = 6,              # depth
    num_classes = 1000,     # output number of classes
    attn_dropout = 0.1,     # attention dropout
    ff_dropout = 0.1        # feedforward dropout
)

img = torch.randn(2, 3, 256, 256)
logits = tnt(img) # (2, 1000)
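
The sizes in the example above determine how many tokens each level of attention sees. The following is a minimal sketch of that arithmetic, assuming square, non-overlapping patches and sub-patches ("pixels"). `token_counts` is a hypothetical helper for illustration and is not part of the library.

```python
def token_counts(image_size, patch_size, pixel_size):
    """Return (patches per image, pixel tokens per patch).

    Assumes square, non-overlapping patches and sub-patches; this only
    illustrates the arithmetic and is not the library's own code.
    """
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    if patch_size % pixel_size != 0:
        raise ValueError("patch_size must be divisible by pixel_size")
    num_patches = (image_size // patch_size) ** 2
    pixels_per_patch = (patch_size // pixel_size) ** 2
    return num_patches, pixels_per_patch


# Values from the usage example: 256x256 image, 16x16 patches, 4x4 pixels.
num_patches, pixels_per_patch = token_counts(256, 16, 4)
print(num_patches, pixels_per_patch)  # 256 16
```

Under these assumptions the outer transformer attends over 256 patch tokens (plus one extra token, per the `num_patch_tokens + 1` in the code quoted in the comments below), and each inner transformer attends over 16 pixel tokens.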

Citations

@misc{han2021transformer,
    title   = {Transformer in Transformer}, 
    author  = {Kai Han and An Xiao and Enhua Wu and Jianyuan Guo and Chunjing Xu and Yunhe Wang},
    year    = {2021},
    eprint  = {2103.00112},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
Comments
  • Only works if pixel_size**2 == patch_size?

    Hi, is this only supposed to work if pixel_size**2 == patch_size? When patch_size is set to any value that doesn't satisfy this equation, the following error occurs:

    --> 146         pixels += rearrange(self.pixel_pos_emb, 'n d -> () n d')
        147 
        148         for pixel_attn, pixel_ff, pixel_to_patch_residual, patch_attn, patch_ff in self.layers:
    
    RuntimeError: The size of tensor a (4) must match the size of tensor b (64) at non-singleton dimension 1
    

    The error came when running

    import torch
    from transformer_in_transformer import TNT

    # configuration that triggered the error above: pixel_size**2 != patch_size
    tnt = TNT(
        image_size = 128,       # size of image
        patch_dim = 256,        # dimension of patch token
        pixel_dim = 24,         # dimension of pixel token
        patch_size = 16,        # patch size
        pixel_size = 2,         # pixel size
        depth = 6,              # depth
        heads = 1,              # number of attention heads
        num_classes = 2,        # output number of classes
        attn_dropout = 0.1,     # attention dropout
        ff_dropout = 0.1        # feedforward dropout
    )
    img = torch.randn(2, 3, 128, 128)
    logits = tnt(img)           # raised the RuntimeError shown above
    

    Since I am completely new to einops, it's quite hard for me to debug :D Thanks

    opened by PhilippMarquardt 1
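
The two sizes in the error message above (4 and 64) are consistent with two different ways of counting pixel tokens per patch. The sketch below only reproduces that arithmetic. Which expression the library used for which tensor is an assumption and is not confirmed by the source, and `counts_agree` is a hypothetical helper.

```python
patch_size, pixel_size = 16, 2  # values from the failing configuration

# Count if each patch is split into pixel_size x pixel_size sub-patches.
from_unfolding = (patch_size // pixel_size) ** 2

# Count if a table is sized from pixel_size alone.
from_pixel_size = pixel_size ** 2

print(from_pixel_size, from_unfolding)  # 4 64


def counts_agree(patch_size, pixel_size):
    """True when both ways of counting give the same number of tokens."""
    return (patch_size // pixel_size) ** 2 == pixel_size ** 2


print(counts_agree(16, 4))  # True
print(counts_agree(16, 2))  # False
```

When pixel_size divides patch_size, the two counts agree exactly when pixel_size**2 == patch_size, which matches the condition asked about in the issue.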
  • Not sure what is wrong!

    RuntimeError                              Traceback (most recent call last)
    in
         14
         15 img = torch.randn(1, 3, 256, 256)
    ---> 16 logits = tnt(img) # (2, 1000)

    ~/opt/anaconda3/envs/ml/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
       1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
       1109                 or _global_forward_hooks or _global_forward_pre_hooks):
    -> 1110             return forward_call(*input, **kwargs)
       1111         # Do not call functions when jit is used
       1112         full_backward_hooks, non_full_backward_hooks = [], []

    ~/opt/anaconda3/envs/ml/lib/python3.8/site-packages/transformer_in_transformer/tnt.py in forward(self, x)
        159         patches = repeat(self.patch_tokens[:(n + 1)], 'n d -> b n d', b = b)
        160
    --> 161         patches += rearrange(self.patch_pos_emb[:(n + 1)], 'n d -> () n d')
        162         pixels += rearrange(self.pixel_pos_emb, 'n d -> () n d')
        163

    RuntimeError: a view of a leaf Variable that requires grad is being used in an in-place operation.

    opened by RisabBiswas 0
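
The RuntimeError above can be reproduced outside the library. The following is a minimal sketch that assumes only PyTorch. The tensor names mirror the traceback, but the code is illustrative and is not the library's forward pass. It shows that an in-place `+=` on a view of a parameter fails, while the out-of-place form `a = a + b` works.

```python
import torch
from torch import nn

patch_tokens = nn.Parameter(torch.randn(5, 8))   # leaf tensor that requires grad
patch_pos_emb = nn.Parameter(torch.randn(5, 8))

n = 4
patches = patch_tokens[:n].unsqueeze(0)          # a view of the leaf, shape (1, 4, 8)

try:
    patches += patch_pos_emb[:n].unsqueeze(0)    # in-place op on a view of a leaf
    raised = False
except RuntimeError as err:
    raised = True
    print(err)

# Out-of-place addition allocates a new tensor, so autograd accepts it.
patches = patch_tokens[:n].unsqueeze(0) + patch_pos_emb[:n].unsqueeze(0)
print(patches.shape)
```

Whether the same change is the right fix for line 161 of tnt.py is an assumption; the sketch only demonstrates the autograd rule behind the message.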
  • patch_tokens vs patch_pos_emb

    Hi!

    I'm trying to understand your TNT implementation, and one thing that confused me is why there are two parameters, patch_tokens and patch_pos_emb, which seem to have the same purpose: to encode patch position. Isn't one of them redundant?

    self.patch_tokens = nn.Parameter(torch.randn(num_patch_tokens + 1, patch_dim))
    self.patch_pos_emb = nn.Parameter(torch.randn(num_patch_tokens + 1, patch_dim))
    ...
    patches = repeat(self.patch_tokens[:(n + 1)], 'n d -> b n d', b = b)
    patches += rearrange(self.patch_pos_emb[:(n + 1)], 'n d -> () n d')
    
    opened by stas-sl 0
  • Inconsistent model params with MindSpore src code

    There's no function or README description for the TNT-S/TNT-B models in this codebase, something like:

    def tnt_b(num_class):
        return TNT(img_size=384,
                   patch_size=16,
                   num_channels=3,
                   embedding_dim=640,
                   num_heads=10,
                   num_layers=12,
                   hidden_dim=640*4,
                   stride=4,
                   num_class=num_class)
    

    Also, the number of heads in the inner block should be 4: https://github.com/lucidrains/transformer-in-transformer/blob/main/transformer_in_transformer/tnt.py#L135

    Has anyone reproduced the results reported in the paper with this codebase?

    opened by WongChen 0
  • Why does the loss become NaN?

    It is a great project. I am very interested in the Transformer in Transformer model. I used your model to train on the Vehicle-1M dataset, a fine-grained visual classification dataset. With this model, the loss becomes NaN after some batch iterations. I decreased the learning rate of the Adam optimizer and clipped the gradients with torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0, norm_type=2), but the loss still sometimes becomes NaN. It seems that the gradients are not large, but they point in the same direction for many iterations. How can I solve it?

    opened by yt7589 3
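
The steps described in the issue above (a lower learning rate plus gradient-norm clipping) can be combined with a check that skips non-finite losses. This is a generic PyTorch sketch and not a confirmed fix for the issue. The small linear model, the learning rate, and the `train_step` helper are placeholders for illustration.

```python
import torch
from torch import nn

model = nn.Linear(16, 4)                       # placeholder standing in for TNT
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()


def train_step(inputs, targets, max_norm=2.0):
    """Run one optimization step; return the loss, or None if it was skipped."""
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    if not torch.isfinite(loss):
        return None                            # skip the update on a NaN/inf loss
    loss.backward()
    # Same clipping call as quoted in the issue text.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm, norm_type=2)
    optimizer.step()
    return loss.item()


loss = train_step(torch.randn(8, 16), torch.randint(0, 4, (8,)))
print(loss)
```

Skipping a batch hides the symptom rather than the cause. Enabling `torch.autograd.set_detect_anomaly(True)` during debugging can help locate the operation that first produces a NaN in the backward pass.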
Owner
Phil Wang
Working with Attention. It's all we need.