Implementation of state-of-the-art vision transformers with TensorFlow

Overview

ViT Tensorflow

This repository contains TensorFlow implementations of state-of-the-art vision transformers, a category of computer vision models first introduced in An Image is Worth 16x16 Words. It is inspired by lucidrains' vit-pytorch. I hope you enjoy these implementations :)

Models

Requirements

pip install tensorflow

Vision Transformer

The Vision Transformer (ViT) was introduced in An Image is Worth 16x16 Words. The model splits an image into patches and classifies it with a Transformer encoder, using pure attention and no convolutions.

Usage

Defining the Model

from vit import ViT
import tensorflow as tf

vitClassifier = ViT(
                    num_classes=1000,
                    patch_size=16,
                    num_of_patches=(224//16)**2,
                    d_model=128,
                    heads=2,
                    num_layers=4,
                    mlp_rate=2,
                    dropout_rate=0.1,
                    prediction_dropout=0.3,
)
Params
  • num_classes: int
    number of classes used for the final classification head
  • patch_size: int
    patch_size used for the tokenization
  • num_of_patches: int
    number of patches after tokenization, used for the positional encoding. In general it can be computed as (((h-patch_size)//patch_size) + 1)*(((w-patch_size)//patch_size) + 1), where h is the height of the image and w is its width. When the height and width are divisible by patch_size, the simpler formula (h//patch_size)*(w//patch_size) can be used as well
  • d_model: int
    hidden dimension of the transformer encoder and the dimension used for the patch embedding
  • heads: int
    number of heads used for the multi-head attention mechanism
  • num_layers: int
    number of blocks in the transformer encoder
  • mlp_rate: int
    the rate of expansion in the feed-forward block of each transformer block (the dimension after expansion is mlp_rate * d_model)
  • dropout_rate: float
    dropout rate used in the multi-head attention mechanism
  • prediction_dropout: float
    dropout rate used in the final prediction head of the model
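
The num_of_patches formulas above can be checked with a small pure-Python sketch. The helper name num_patches is hypothetical and not part of this repository. For images at least as large as one patch, the two expressions give the same result, because floor division discards the remainder either way.

```python
def num_patches(h: int, w: int, patch_size: int) -> int:
    """Count non-overlapping patch_size x patch_size patches in an h x w image.

    Hypothetical helper for illustration; it mirrors the general formula
    given for num_of_patches.
    """
    rows = (h - patch_size) // patch_size + 1
    cols = (w - patch_size) // patch_size + 1
    return rows * cols


# 224 is divisible by 16, so this matches (224 // 16) ** 2.
print(num_patches(224, 224, 16))  # 196

# 230 is not divisible by 16; the trailing 6 pixels are dropped on each axis.
print(num_patches(230, 230, 16) == (230 // 16) * (230 // 16))  # True
```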

Inference

sampleInput = tf.random.normal(shape=(1 , 224 , 224 , 3))
output = vitClassifier(sampleInput , training=False)
print(output.shape) # (1 , 1000)

Training

vitClassifier.compile(
              loss=tf.keras.losses.SparseCategoricalCrossentropy(),
              optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
              metrics=[
                       tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
                       tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5 , name="top_5_accuracy"),
              ])

vitClassifier.fit(
              trainingData, #Tensorflow dataset of images and labels in shape of ((b , h , w , 3) , (b,))
              validation_data=valData, #The same as training
              epochs=100,)

Convolutional Vision Transformer

The Convolutional Vision Transformer (CvT) was introduced in CvT: Introducing Convolutions to Vision Transformers. This model uses a hierarchical (multi-stage) architecture with a convolutional embedding at the beginning of each stage. It also uses Convolutional Transformer Blocks, which improve on the original vision transformer by adding the inductive bias of CNNs to the architecture.

Usage

Defining the Model

from cvt import CvT , CvTStage
import tensorflow as tf

cvtModel = CvT(
num_of_classes=1000, 
stages=[
        CvTStage(projectionDim=64, 
                 heads=1, 
                 embeddingWindowSize=(7 , 7), 
                 embeddingStrides=(4 , 4), 
                 layers=1,
                 projectionWindowSize=(3 , 3), 
                 projectionStrides=(2 , 2), 
                 ffnRate=4,
                 dropoutRate=0.1),
        CvTStage(projectionDim=192,
                 heads=3,
                 embeddingWindowSize=(3 , 3), 
                 embeddingStrides=(2 , 2),
                 layers=1, 
                 projectionWindowSize=(3 , 3), 
                 projectionStrides=(2 , 2), 
                 ffnRate=4,
                 dropoutRate=0.1),
        CvTStage(projectionDim=384,
                 heads=6,
                 embeddingWindowSize=(3 , 3),
                 embeddingStrides=(2 , 2),
                 layers=1,
                 projectionWindowSize=(3 , 3),
                 projectionStrides=(2 , 2), 
                 ffnRate=4,
                 dropoutRate=0.1)
],
dropout=0.5)
CvT Params
  • num_of_classes: int
    number of classes used in the final prediction layer
  • stages: list of CvTStage
    list of cvt stages
  • dropout: float
    dropout rate used for the prediction head
CvTStage Params
  • projectionDim: int
    dimension used for the multi-head attention mechanism and the convolutional embedding
  • heads: int
    number of heads in the multi-head attention mechanism
  • embeddingWindowSize: tuple(int , int)
    window size used for the convolutional embedding
  • embeddingStrides: tuple(int , int)
    strides used for the convolutional embedding
  • layers: int
    number of convolutional transformer blocks
  • projectionWindowSize: tuple(int , int)
    window size used for the convolutional projection in each convolutional transformer block
  • projectionStrides: tuple(int , int)
    strides used for the convolutional projection in each convolutional transformer block
  • ffnRate: int
    expansion rate of the mlp block in each convolutional transformer block
  • dropoutRate: float
    dropout rate used in each convolutional transformer block
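
To see how embeddingStrides shapes the hierarchy, the sketch below estimates the token-map resolution after each stage's convolutional embedding. It assumes the embedding convolutions use "same" padding, so each output side is ceil(input / stride); the padding actually used by this repository is not stated here, and the helper names are hypothetical.

```python
import math


def same_padding_output(size: int, stride: int) -> int:
    # Output length of a strided convolution under the ASSUMED "same" padding.
    return math.ceil(size / stride)


def stage_resolutions(image_size, embedding_strides):
    """Return the (height, width) of the token map after each stage's embedding."""
    h, w = image_size
    sizes = []
    for stride_h, stride_w in embedding_strides:
        h = same_padding_output(h, stride_h)
        w = same_padding_output(w, stride_w)
        sizes.append((h, w))
    return sizes


# Strides taken from the three CvTStage definitions above.
resolutions = stage_resolutions((224, 224), [(4, 4), (2, 2), (2, 2)])
print(resolutions)  # [(56, 56), (28, 28), (14, 14)]
```

Under that assumption, the three stages process 3136, 784 and 196 tokens respectively, while projectionDim grows from 64 to 384.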

Inference

sampleInput = tf.random.normal(shape=(1 , 224 , 224 , 3))
output = cvtModel(sampleInput , training=False)
print(output.shape) # (1 , 1000)

Training

cvtModel.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        metrics=[
                 tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
                 tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5 , name="top_5_accuracy"),
        ])

cvtModel.fit(
        trainingData, #Tensorflow dataset of images and labels in shape of ((b , h , w , 3) , (b,))
        validation_data=valData, #The same as training
        epochs=100,)

Pyramid Vision Transformer V1

Pyramid Vision Transformer V1 was introduced in Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions. This model stacks multiple Transformer encoders to form the first convolution-free multi-scale backbone for a range of visual tasks, including image segmentation and object detection. The paper also introduces a new attention mechanism, Spatial Reduction Attention (SRA), which reduces the quadratic cost of multi-head attention.

Usage

Defining the Model

from pvt_v1 import PVT , PVTStage
import tensorflow as tf

pvtModel = PVT(
num_of_classes=1000, 
stages=[
        PVTStage(d_model=64,
                 patch_size=(2 , 2),
                 heads=1,
                 reductionFactor=2,
                 mlp_rate=2,
                 layers=2, 
                 dropout_rate=0.1),
        PVTStage(d_model=128,
                 patch_size=(2 , 2),
                 heads=2, 
                 reductionFactor=2, 
                 mlp_rate=2, 
                 layers=2, 
                 dropout_rate=0.1),
        PVTStage(d_model=320,
                 patch_size=(2 , 2),
                 heads=5, 
                 reductionFactor=2, 
                 mlp_rate=2, 
                 layers=2, 
                 dropout_rate=0.1),
],
dropout=0.5)
PVT Params
  • num_of_classes: int
    number of classes used in the final prediction layer
  • stages: list of PVTStage
    list of pvt stages
  • dropout: float
    dropout rate used for the prediction head
PVTStage Params
  • d_model: int
    dimension used for the SRA mechanism and the patch embedding
  • patch_size: tuple(int , int)
    window size used for the patch embedding
  • heads: int
    number of heads in the SRA mechanism
  • reductionFactor: int
    reduction factor used for the down sampling of the K and V in the SRA mechanism
  • mlp_rate: int
    expansion rate used in the feed-forward block
  • layers: int
    number of transformer encoders
  • dropout_rate: float
    dropout rate used in each transformer encoder
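
The effect of reductionFactor on attention cost can be estimated with simple arithmetic. This sketch follows the PVT paper, where K and V are downsampled by the reduction factor along each spatial axis, so their token count shrinks by the factor squared. Whether this repository applies reductionFactor in exactly this way is an assumption, and the 56 x 56 token map is a made-up example.

```python
def attention_score_entries(num_tokens: int, reduction_factor: int = 1) -> int:
    """Number of query-key scores computed per head.

    ASSUMPTION: K and V are reduced by reduction_factor along both height
    and width, as described for SRA in the PVT paper.
    """
    kv_tokens = num_tokens // (reduction_factor ** 2)
    return num_tokens * kv_tokens


n = 56 * 56  # hypothetical 56 x 56 token map
full = attention_score_entries(n)                      # standard attention
sra = attention_score_entries(n, reduction_factor=2)   # SRA with factor 2
print(full, sra, full // sra)  # 9834496 2458624 4
```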

Inference

sampleInput = tf.random.normal(shape=(1 , 224 , 224 , 3))
output = pvtModel(sampleInput , training=False)
print(output.shape) # (1 , 1000)

Training

pvtModel.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        metrics=[
                 tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
                 tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5 , name="top_5_accuracy"),
        ])

pvtModel.fit(
        trainingData, #Tensorflow dataset of images and labels in shape of ((b , h , w , 3) , (b,))
        validation_data=valData, #The same as training
        epochs=100,)

Pyramid Vision Transformer V2

Pyramid Vision Transformer V2 was introduced in PVT v2: Improved Baselines with Pyramid Vision Transformer. This model is an improved version of PVT V1. The improvements are as follows:

  1. It uses overlapping patch embedding by using padded convolutions
  2. It uses convolutional feed-forward blocks which have a depth-wise convolution after the first fully-connected layer
  3. It uses fixed-size pooling instead of convolutions to downsample K and V in the SRA mechanism (the new attention mechanism is called Linear SRA)

Usage

Defining the Model

from pvt_v2 import PVTV2 , PVTV2Stage
import tensorflow as tf

pvtV2Model = PVTV2(
num_of_classes=1000, 
stages=[
        PVTV2Stage(d_model=64,
                   windowSize=(2 , 2), 
                   heads=1,
                   poolingSize=(7 , 7), 
                   mlp_rate=2, 
                   mlp_windowSize=(3 , 3), 
                   layers=2, 
                   dropout_rate=0.1),
        PVTV2Stage(d_model=128, 
                   windowSize=(2 , 2),
                   heads=2,
                   poolingSize=(7 , 7), 
                   mlp_rate=2, 
                   mlp_windowSize=(3 , 3), 
                   layers=2,
                   dropout_rate=0.1),
        PVTV2Stage(d_model=320,
                   windowSize=(2 , 2), 
                   heads=5, 
                   poolingSize=(7 , 7), 
                   mlp_rate=2, 
                   mlp_windowSize=(3 , 3), 
                   layers=2, 
                   dropout_rate=0.1),
],
dropout=0.5)
PVTV2 Params
  • num_of_classes: int
    number of classes used in the final prediction layer
  • stages: list of PVTV2Stage
    list of pvt v2 stages
  • dropout: float
    dropout rate used for the prediction head
PVTV2Stage Params
  • d_model: int
    dimension used for the Linear SRA mechanism and the convolutional patch embedding
  • windowSize: tuple(int , int)
    window size used for the convolutional patch embedding
  • heads: int
    number of heads in the Linear SRA mechanism
  • poolingSize: tuple(int , int)
    size of the K and V after the fixed pooling
  • mlp_rate: int
    expansion rate used in the convolutional feed-forward block
  • mlp_windowSize: tuple(int , int)
    the window size used for the depth-wise convolution in the convolutional feed-forward block
  • layers: int
    number of transformer encoders
  • dropout_rate: float
    dropout rate used in each transformer encoder
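
Because poolingSize fixes the size of K and V, the number of attention scores grows linearly with the number of query tokens rather than quadratically. The sketch below illustrates this with plain arithmetic; the helper name and the token-map sizes are hypothetical.

```python
def linear_sra_score_entries(num_tokens: int, pooling_size=(7, 7)) -> int:
    """Query-key scores per head when K and V are pooled to a fixed size."""
    kv_tokens = pooling_size[0] * pooling_size[1]  # independent of num_tokens
    return num_tokens * kv_tokens


# Compare with standard attention (n * n) on a few hypothetical token maps.
for side in (14, 28, 56):
    n = side * side
    print(side, n * n, linear_sra_score_entries(n))
```

With poolingSize=(7 , 7), K and V always hold 49 tokens, so quadrupling the number of query tokens only quadruples the cost.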

Inference

sampleInput = tf.random.normal(shape=(1 , 224 , 224 , 3))
output = pvtV2Model(sampleInput , training=False)
print(output.shape) # (1 , 1000)

Training

pvtV2Model.compile(
          loss=tf.keras.losses.SparseCategoricalCrossentropy(),
          optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
          metrics=[
                   tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
                   tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5 , name="top_5_accuracy"),
          ])

pvtV2Model.fit(
          trainingData, #Tensorflow dataset of images and labels in shape of ((b , h , w , 3) , (b,))
          validation_data=valData, #The same as training
          epochs=100,)

DeiT

DeiT was introduced in Training Data-Efficient Image Transformers & Distillation Through Attention. Because the original vision transformer lacks the inductive biases of CNNs, it is data hungry: it needs a large amount of training data to surpass state-of-the-art CNNs such as ResNet. To address this, the authors use a pre-trained CNN (such as a ResNet) as a teacher during training, together with a special loss function that performs distillation through attention.

Usage

Defining the Model

from deit import DeiT
import tensorflow as tf

teacherModel = tf.keras.applications.ResNet50(include_top=True, 
                                              weights="imagenet", 
                                              input_shape=(224 , 224 , 3))

deitModel = DeiT(
                 num_classes=1000,
                 patch_size=16,
                 num_of_patches=(224//16)**2,
                 d_model=128,
                 heads=2,
                 num_layers=4,
                 mlp_rate=2,
                 teacherModel=teacherModel,
                 temperature=1.0, 
                 alpha=0.5,
                 hard=False, 
                 dropout_rate=0.1,
                 prediction_dropout=0.3,
)
Params
  • num_classes: int
    number of classes used for the final classification head
  • patch_size: int
    patch_size used for the tokenization
  • num_of_patches: int
    number of patches after tokenization, used for the positional encoding. In general it can be computed as (((h-patch_size)//patch_size) + 1)*(((w-patch_size)//patch_size) + 1), where h is the height of the image and w is its width. When the height and width are divisible by patch_size, the simpler formula (h//patch_size)*(w//patch_size) can be used as well
  • d_model: int
    hidden dimension of the transformer encoder and the dimension used for the patch embedding
  • heads: int
    number of heads used for the multi-head attention mechanism
  • num_layers: int
    number of blocks in the transformer encoder
  • mlp_rate: int
    the rate of expansion in the feed-forward block of each transformer block (the dimension after expansion is mlp_rate * d_model)
  • teacherModel: Tensorflow Model
    the teacher model used for distillation during training. This is a pre-trained CNN with the same input_shape and output_shape as the Transformer
  • temperature: float
    the temperature parameter in the loss
  • alpha: float
    the coefficient balancing the Kullback–Leibler divergence loss (KL) and the cross-entropy loss
  • hard: bool
    whether to use hard-label distillation (True) or soft distillation (False)
  • dropout_rate: float
    dropout rate used in the multi-head attention mechanism
  • prediction_dropout: float
    dropout rate used in the final prediction head of the model
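
The roles of temperature, alpha and hard can be illustrated with a pure-Python sketch of the two distillation objectives described in the DeiT paper, for a single example. This is not the loss code of this repository: the function names are hypothetical, and the exact weighting used inside DeiT (for instance, how alpha is applied in hard mode) may differ.

```python
import math


def softmax(logits, temperature=1.0):
    # Numerically stable softmax over temperature-scaled logits.
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(probs, label):
    return -math.log(probs[label])


def kl_divergence(p, q):
    # KL(p || q) for two discrete distributions.
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def distillation_loss(student_logits, teacher_logits, label,
                      temperature=1.0, alpha=0.5, hard=False):
    """Paper-style DeiT objective for one example (ASSUMED form, see lead-in)."""
    ce = cross_entropy(softmax(student_logits), label)
    if hard:
        # Hard-label distillation: the teacher's argmax acts as a second label.
        teacher_label = max(range(len(teacher_logits)),
                            key=teacher_logits.__getitem__)
        return 0.5 * ce + 0.5 * cross_entropy(softmax(student_logits), teacher_label)
    # Soft distillation: KL(teacher || student) on temperature-softened outputs.
    kl = kl_divergence(softmax(teacher_logits, temperature),
                       softmax(student_logits, temperature))
    return (1 - alpha) * ce + alpha * temperature ** 2 * kl


student = [2.0, 0.5, -1.0]
teacher = [0.0, 3.0, 0.0]
print(distillation_loss(student, teacher, label=0))             # soft
print(distillation_loss(student, teacher, label=0, hard=True))  # hard
```

In the soft case, alpha=0 reduces the loss to plain cross-entropy, while larger values of alpha push the student toward the teacher's softened distribution.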

Inference

sampleInput = tf.random.normal(shape=(1 , 224 , 224 , 3))
output = deitModel(sampleInput , training=False)
print(output.shape) # (1 , 1000)

Training

# Note: the loss is defined inside the model, so no loss should be passed here
deitModel.compile(
         optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
         metrics=[
                  tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
                  tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5 , name="top_5_accuracy"),
         ])

deitModel.fit(
         trainingData, #Tensorflow dataset of images and labels in shape of ((b , h , w , 3) , (b , num_classes))
         validation_data=valData, #The same as training
         epochs=100,)
Owner
Mohammadmahdi NouriBorji