TensorFlow implementation of the Swin Transformer model.

Overview

Swin Transformer (TensorFlow)

TensorFlow reimplementation of the Swin Transformer model.

Based on the official PyTorch implementation.

Requirements

  • tensorflow >= 2.4.1

Pretrained Swin Transformer Checkpoints

ImageNet-1K and ImageNet-22K Pretrained Checkpoints

name            pretrain      resolution  acc@1  #params  model
swin_tiny_224   ImageNet-1K   224x224     81.2   28M      github
swin_small_224  ImageNet-1K   224x224     83.2   50M      github
swin_base_224   ImageNet-22K  224x224     85.2   88M      github
swin_base_384   ImageNet-22K  384x384     86.4   88M      github
swin_large_224  ImageNet-22K  224x224     86.3   197M     github
swin_large_384  ImageNet-22K  384x384     87.3   197M     github

Examples

Initializing the model:

from swintransformer import SwinTransformer

model = SwinTransformer('swin_tiny_224', num_classes=1000, include_top=True, pretrained=False)

To fine-tune a pretrained backbone on your own classes, wrap it with a preprocessing layer and a new classification head:

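import tensorflow as tf
from swintransformer import SwinTransformer

IMAGE_SIZE = [224, 224]  # should match the input resolution of the chosen model
NUM_CLASSES = 10         # example value; set to the number of classes in your dataset

model = tf.keras.Sequential([
  # Scale pixels to [0, 1] and normalize with the ImageNet mean/std ("torch" mode)
  tf.keras.layers.Lambda(lambda data: tf.keras.applications.imagenet_utils.preprocess_input(tf.cast(data, tf.float32), mode="torch"), input_shape=[*IMAGE_SIZE, 3]),
  SwinTransformer('swin_tiny_224', include_top=False, pretrained=True),
  tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')
])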
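The Lambda layer relies on preprocess_input(..., mode="torch"), which, to my understanding of the Keras source, divides pixels by 255 and then normalizes each RGB channel with the ImageNet mean and standard deviation (the same statistics the official PyTorch pipeline uses). A NumPy sketch of that transform, handy for checking inputs outside the model; torch_style_preprocess is a hypothetical helper, not part of this package:

```python
import numpy as np

# ImageNet channel statistics used by preprocess_input(mode="torch").
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def torch_style_preprocess(images):
    """Scale RGB pixels from [0, 255] to [0, 1], then normalize each channel.

    `images` is a channels-last array of shape (..., H, W, 3).
    """
    x = np.asarray(images, dtype=np.float32) / 255.0
    # Broadcasting applies the per-channel statistics along the last axis.
    return (x - IMAGENET_MEAN) / IMAGENET_STD


batch = np.full((2, 224, 224, 3), 255, dtype=np.uint8)  # all-white dummy batch
out = torch_style_preprocess(batch)
print(out.shape)
```

No channel reordering happens in this mode, so images should be fed as RGB.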

If you use a pretrained model on a TPU on Kaggle, set the use_tpu option:

import tensorflow as tf
from swintransformer import SwinTransformer

IMAGE_SIZE = [224, 224]  # should match the input resolution of the chosen model
NUM_CLASSES = 10         # example value; set to the number of classes in your dataset

model = tf.keras.Sequential([
  tf.keras.layers.Lambda(lambda data: tf.keras.applications.imagenet_utils.preprocess_input(tf.cast(data, tf.float32), mode="torch"), input_shape=[*IMAGE_SIZE, 3]),
  SwinTransformer('swin_tiny_224', include_top=False, pretrained=True, use_tpu=True),
  tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')
])

Example: TPU training on Kaggle

Citation

@article{liu2021Swin,
  title={Swin Transformer: Hierarchical Vision Transformer using Shifted Windows},
  author={Liu, Ze and Lin, Yutong and Cao, Yue and Hu, Han and Wei, Yixuan and Zhang, Zheng and Lin, Stephen and Guo, Baining},
  journal={arXiv preprint arXiv:2103.14030},
  year={2021}
}
Comments
  • No module named 'swintransformer' error

    I wonder where the module in "from swintransformer import SwinTransformer" comes from. I tried to pip install it, but pip said there is no such module. How can I overcome this problem?

    opened by HunarAA
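The import resolves only when the directory containing the repository's swintransformer/ package (see the swintransformer/model.py path in the repository) is on Python's module search path; the report suggests no package of that name was found by pip. A minimal sketch, assuming you have cloned https://github.com/rishigami/Swin-Transformer-TF locally; add_repo_to_path is a hypothetical helper, and simply running your script from the clone's root directory achieves the same thing:

```python
import importlib
import sys
from pathlib import Path


def add_repo_to_path(repo_root):
    """Put a local clone of the repository at the front of sys.path.

    Returns True if `repo_root` contains a `swintransformer/` package directory,
    False otherwise (sys.path is left untouched in that case).
    """
    root = Path(repo_root).resolve()
    if not (root / "swintransformer").is_dir():
        return False
    if str(root) not in sys.path:
        sys.path.insert(0, str(root))
    importlib.invalidate_caches()  # let the import system see the new path entry
    return True


# Usage, after `git clone https://github.com/rishigami/Swin-Transformer-TF`:
# if add_repo_to_path("Swin-Transformer-TF"):
#     from swintransformer import SwinTransformer
```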
  • Pretrained Swin-Transformer for multiple output

    Hi rishigami,

    Thank you for the TensorFlow implementation. I am trying to use the Swin Transformer for a classification problem with multiple outputs. In your guide on how to use a pretrained model you put it in a Sequential model, but that way I cannot stack multiple Dense layers for the separate classification outputs. Could you help me understand how to adapt your TF code to my problem, perhaps using the Functional API?

    opened by imanuelroz
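One way to do this with the Functional API is to call the backbone once on an Input tensor and attach a separate Dense head per output. A minimal sketch, assuming SwinTransformer(..., include_top=False) returns a (batch, features) vector, as the README's Sequential example implies; build_multi_output_model, the head names and the class counts are hypothetical, and a small stand-in backbone is used so the sketch runs without downloading weights:

```python
import tensorflow as tf


def build_multi_output_model(backbone, image_size=(224, 224),
                             heads=(("task_a", 5), ("task_b", 3))):
    """Attach one softmax head per task to a shared feature-extracting backbone.

    `backbone` is any Keras layer or model mapping (batch, H, W, 3) images to
    (batch, features) vectors, e.g. SwinTransformer(..., include_top=False).
    Preprocessing (the README's Lambda layer) is omitted here for brevity.
    """
    inputs = tf.keras.Input(shape=(*image_size, 3))
    features = backbone(inputs)  # shared features, computed once per image
    outputs = {
        name: tf.keras.layers.Dense(num_classes, activation="softmax", name=name)(features)
        for name, num_classes in heads
    }
    return tf.keras.Model(inputs=inputs, outputs=outputs)


# Stand-in backbone so the sketch runs without pretrained weights.
toy_backbone = tf.keras.Sequential([
    tf.keras.layers.Conv2D(8, 3, activation="relu"),
    tf.keras.layers.GlobalAveragePooling2D(),
])
model = build_multi_output_model(toy_backbone, image_size=(32, 32))
predictions = model(tf.zeros((2, 32, 32, 3)))  # dict: head name -> probabilities
```

With dictionary outputs, losses can be given per head, e.g. model.compile(optimizer="adam", loss={"task_a": "categorical_crossentropy", "task_b": "categorical_crossentropy"}).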
  • NotImplementedError during model save

    I have defined a model as follows:

    import tensorflow as tf
    from tensorflow.keras import layers as L
    from swintransformer import SwinTransformer

    LR = 1e-4  # learning rate (example value)

    def buildModel(LR = LR):
        backbone = SwinTransformer('swin_large_224', num_classes=None, include_top=False, pretrained=True, use_tpu=False)
        
        inp = L.Input(shape=(224,224,3))
        emb = backbone(inp)
        out = L.Dense(1,activation="relu")(emb)
        
        model = tf.keras.Model(inputs=inp,outputs=out)
        optimizer = tf.keras.optimizers.Adam(learning_rate=LR)
        model.compile(loss="mse",optimizer=optimizer,metrics=[tf.keras.metrics.RootMeanSquaredError()])
        return model

    model = buildModel()
    

    Now when I save this model using model.save("./model.hdf5") I get the following error:

    NotImplementedError                       Traceback (most recent call last)
    /tmp/ipykernel_43/131311624.py in <module>
    ----> 1 model.save("model.hdf5")
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in save(self, filepath, overwrite, include_optimizer, save_format, signatures, options, save_traces)
       2000     # pylint: enable=line-too-long
       2001     save.save_model(self, filepath, overwrite, include_optimizer, save_format,
    -> 2002                     signatures, options, save_traces)
       2003 
       2004   def save_weights(self,
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/saving/save.py in save_model(model, filepath, overwrite, include_optimizer, save_format, signatures, options, save_traces)
        152           'or using `save_weights`.')
        153     hdf5_format.save_model_to_hdf5(
    --> 154         model, filepath, overwrite, include_optimizer)
        155   else:
        156     saved_model_save.save(model, filepath, overwrite, include_optimizer,
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/saving/hdf5_format.py in save_model_to_hdf5(model, filepath, overwrite, include_optimizer)
        113 
        114   try:
    --> 115     model_metadata = saving_utils.model_metadata(model, include_optimizer)
        116     for k, v in model_metadata.items():
        117       if isinstance(v, (dict, list, tuple)):
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/saving/saving_utils.py in model_metadata(model, include_optimizer, require_config)
        156   except NotImplementedError as e:
        157     if require_config:
    --> 158       raise e
        159 
        160   metadata = dict(
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/saving/saving_utils.py in model_metadata(model, include_optimizer, require_config)
        153   model_config = {'class_name': model.__class__.__name__}
        154   try:
    --> 155     model_config['config'] = model.get_config()
        156   except NotImplementedError as e:
        157     if require_config:
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/functional.py in get_config(self)
        648 
        649   def get_config(self):
    --> 650     return copy.deepcopy(get_network_config(self))
        651 
        652   @classmethod
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/functional.py in get_network_config(network, serialize_layer_fn)
       1347         filtered_inbound_nodes.append(node_data)
       1348 
    -> 1349     layer_config = serialize_layer_fn(layer)
       1350     layer_config['name'] = layer.name
       1351     layer_config['inbound_nodes'] = filtered_inbound_nodes
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/utils/generic_utils.py in serialize_keras_object(instance)
        248         return serialize_keras_class_and_config(
        249             name, {_LAYER_UNDEFINED_CONFIG_KEY: True})
    --> 250       raise e
        251     serialization_config = {}
        252     for key, item in config.items():
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/utils/generic_utils.py in serialize_keras_object(instance)
        243     name = get_registered_name(instance.__class__)
        244     try:
    --> 245       config = instance.get_config()
        246     except NotImplementedError as e:
        247       if _SKIP_FAILED_SERIALIZATION:
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in get_config(self)
       2252 
       2253   def get_config(self):
    -> 2254     raise NotImplementedError
       2255 
       2256   @classmethod
    
    NotImplementedError: 
    
    opened by Bibhash123
  • Invalid argument

    This is my basic model:

    # tpu_strategy: a tf.distribute.TPUStrategy created earlier; cfg and RMSE are defined elsewhere
    with tpu_strategy.scope():
        model = tf.keras.Sequential([
            tf.keras.layers.Lambda(lambda data: tf.keras.applications.imagenet_utils.preprocess_input(tf.cast(data, tf.float32), mode="torch"),
                                   input_shape=[224, 224, 3]),
            SwinTransformer('swin_tiny_224', include_top=False, pretrained=True, use_tpu=True),
            tf.keras.layers.Dense(1, activation='sigmoid')
        ])

    model.compile(loss=tf.keras.losses.BinaryCrossentropy(),
                  optimizer=tf.keras.optimizers.Adam(learning_rate=cfg['LEARNING_RATE']),
                  metrics=RMSE)

    I am getting this error:

    (3) Invalid argument: {{function_node __inference_train_function_705020}} Reshape's input dynamic dimension is decomposed into multiple output dynamic dimensions, but the constraint is ambiguous and XLA can't infer the output dimension %reshape.12202 = f32[256,144,576]{2,1,0} reshape(f32[36864,576]{1,0} %transpose.12194), metadata={op_type="Reshape" op_name="sequential_40/swin_large_384/sequential_39/basic_layer_28/sequential_35/swin_transformer_block_169/window_attention_169/layers0/blocks1/attn/qkv/Tensordot"}. [[{{node TPUReplicate/_compile/_17658394825749957328/_4}}]] [[tpu_compile_succeeded_assert/_11424487196827204192/_5/_209]]

    opened by AliKayhanAtay
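The shapes in this error can be reproduced by hand. The op name points to swin_large_384, first stage (layers0); assuming the usual Swin-L settings (patch size 4, embed_dim 192, window size 12), a 384x384 input gives a 96x96 token grid, i.e. 64 windows of 144 tokens per image, and the fused qkv projection has 3 * 192 = 576 channels. The reshape therefore splits a leading axis of batch * 64 * 144 rows, which XLA cannot do unambiguously when the batch size is dynamic. Because TPUs compile with static shapes, one commonly suggested remedy (untested for this report) is to make the batch dimension static, e.g. dataset.batch(batch_size, drop_remainder=True) in the tf.data pipeline. window_qkv_shapes is a hypothetical helper:

```python
def window_qkv_shapes(batch_size, image_size=384, patch_size=4, window_size=12, embed_dim=192):
    """Input/output shapes of the reshape after the first-stage qkv projection.

    Defaults assume a Swin-L 384 configuration.
    """
    side = image_size // patch_size                  # tokens per side after patch embedding
    windows_per_image = (side // window_size) ** 2   # non-overlapping attention windows
    num_windows = batch_size * windows_per_image     # windows are folded into the batch axis
    tokens_per_window = window_size ** 2
    qkv_dim = 3 * embed_dim                          # q, k and v are projected together
    flat = (num_windows * tokens_per_window, qkv_dim)    # 2-D input of the reshape
    batched = (num_windows, tokens_per_window, qkv_dim)  # 3-D output of the reshape
    return flat, batched


print(window_qkv_shapes(batch_size=4))  # ((36864, 576), (256, 144, 576)), as in the error
```

The 256 windows in the error thus correspond to a per-replica batch of 4 images.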
  • relative_position_bias_table initialization

    Hi, in the official code, relative_position_bias_table is initialized from a truncated normal distribution. Is that part missing in this repo?

    Official code: https://github.com/microsoft/Swin-Transformer/blob/6bbd83ca617db8480b2fb9b335c476ffaf5afb1a/models/swin_transformer.py#L110

    This implementation: https://github.com/rishigami/Swin-Transformer-TF/blob/8986ca7b0e1f984437db2d8f17e0ecd87fadcd4f/swintransformer/model.py#L70

    opened by gathierry
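For reference, the official code calls trunc_normal_(self.relative_position_bias_table, std=.02). The table holds one learnable bias per relative offset and head, so it has (2*Wh-1)*(2*Ww-1) rows. A NumPy sketch of such an initialization; truncated_normal is a hypothetical helper that redraws samples beyond two standard deviations, which is how tf.keras.initializers.TruncatedNormal(stddev=0.02) behaves. Note that timm's trunc_normal_ truncates at absolute bounds [-2, 2] by default, so at std 0.02 its truncation practically never triggers; the two differ only in the tails:

```python
import numpy as np


def truncated_normal(shape, std=0.02, rng=None):
    """Draw N(0, std^2) samples, redrawing any that fall outside +/- 2*std."""
    rng = np.random.default_rng() if rng is None else rng
    values = rng.normal(0.0, std, size=shape)
    outside = np.abs(values) > 2 * std
    while outside.any():
        # Redraw only the rejected entries until every value is in range.
        values[outside] = rng.normal(0.0, std, size=int(outside.sum()))
        outside = np.abs(values) > 2 * std
    return values


window_size, num_heads = 7, 3  # e.g. the first stage of swin_tiny_224
table = truncated_normal(((2 * window_size - 1) ** 2, num_heads), rng=np.random.default_rng(0))
print(table.shape)  # (169, 3)
```

In a TF port, the equivalent would presumably be passing that Keras initializer when the table weight is created.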
  • Image size other than default ones doesn't work

    • Notebook: https://colab.research.google.com/drive/1nqYkQCUzShkVdqGxW4TyMrtAb0n5MBZR#scrollTo=G9ZVlphmqD7d
    • Issue: with swin_tiny_224 I've tried multiples of 224, 512x512, and multiples of window_size, but nothing other than 224x224 seems to work.
    • The same goes for swin_large_384: only the default size, 384x384, works.

    I'm wondering whether this is expected behavior. Is there any way to make it work for non-square images?

    opened by awsaf49
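Windowed attention puts shape constraints on the input: after the patch embedding (stride 4) the token grid is halved by patch merging at each of the following stages, each stage's grid must split evenly into windows, and, as in the official PyTorch code, the window shrinks to the grid size when the grid is smaller than the window. A pure-Python sketch of these necessary conditions; the helper names are hypothetical, and passing the check does not guarantee that this TF port accepts the size, since the reports show other constraints (such as a resolution fixed when the model is built) may also apply:

```python
def stage_resolutions(height, width, patch_size=4, num_stages=4):
    """Token-grid size at each stage: patch embedding, then 2x patch merging."""
    h, w = height // patch_size, width // patch_size
    sizes = []
    for _ in range(num_stages):
        sizes.append((h, w))
        h, w = h // 2, w // 2
    return sizes


def fits_swin_windows(height, width, patch_size=4, window_size=7, num_stages=4):
    """Necessary shape conditions for window attention at every stage."""
    if height % patch_size or width % patch_size:
        return False
    sizes = stage_resolutions(height, width, patch_size, num_stages)
    for i, (h, w) in enumerate(sizes):
        if min(h, w) == 0:
            return False
        # The window shrinks to the grid when the grid is not larger than the window.
        ws = min(window_size, h, w)
        if h % ws or w % ws:
            return False
        # Patch merging between stages needs an even grid.
        if i < num_stages - 1 and (h % 2 or w % 2):
            return False
    return True


print(fits_swin_windows(224, 224), fits_swin_windows(512, 512))  # True False
```

By this arithmetic, sizes that pass with window 7 are multiples of 224 per side (224 * k by 224 * m, including non-square ones), while 512 fails because its 128x128 first-stage grid is not divisible by 7.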
  • Added 3D support for SwinTransformerModel, e.g. for medical imaging tasks

    Tested and working, e.g.:

    import numpy as np
    import tensorflow as tf
    import swin_transformer_nd  # module added by this pull request (import path assumed)

    IMAGE_SIZE = [112, 112, 112]
    NUM_CLASSES = 10
    
    model_3d = tf.keras.Sequential([
      swin_transformer_nd.SwinTransformerModel(img_size=IMAGE_SIZE, patch_size=(4, 4, 4), depths=[2, 2, 6]),
      tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')
    ])
    model_3d.compile(tf.keras.optimizers.Adam(), "categorical_crossentropy")
    
    # Smoke test: fit repeatedly on a single all-zero, single-channel volume
    for i in range(100):
        x = np.zeros([1, *IMAGE_SIZE, 1])
        y = tf.zeros([1, NUM_CLASSES])
        
        model_3d.fit(x, y)
        print("Trained on a batch")
    
    opened by MohamadZeina
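In 3D, Swin-style patch embedding starts by cutting each volume into non-overlapping pd x ph x pw patches; a linear projection of each flattened patch (equivalently a strided Conv3D) then gives the patch tokens. With IMAGE_SIZE = [112, 112, 112] and patch_size=(4, 4, 4) that is 28^3 = 21952 tokens of 64 values each. A NumPy sketch of the partition step only; patch_partition_3d is a hypothetical helper, not the code in this pull request:

```python
import numpy as np


def patch_partition_3d(volume, patch_size=(4, 4, 4)):
    """Split (batch, D, H, W, C) volumes into flattened non-overlapping patches.

    Returns (batch, num_patches, pd * ph * pw * C), patches ordered D, H, W.
    """
    b, d, h, w, c = volume.shape
    pd, ph, pw = patch_size
    if d % pd or h % ph or w % pw:
        raise ValueError("each spatial size must be divisible by the patch size")
    x = volume.reshape(b, d // pd, pd, h // ph, ph, w // pw, pw, c)
    x = x.transpose(0, 1, 3, 5, 2, 4, 6, 7)  # patch-grid axes first, within-patch axes last
    return x.reshape(b, (d // pd) * (h // ph) * (w // pw), pd * ph * pw * c)


patches = patch_partition_3d(np.zeros((1, 112, 112, 112, 1), dtype=np.float32))
print(patches.shape)  # (1, 21952, 64)
```

The token count grows with the cube of the side length, which is why 3D models typically use small volumes or larger patches.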
  • Could you provide the weight conversion script?

    I tried the code and weights you provided and found the performance to be poor. Could you please provide the weight conversion script so I can figure out this issue?

    Many thanks

    opened by edwardyehuang
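The repository's conversion script is not shown here, but any PyTorch-to-Keras port has to reorder weight layouts: nn.Linear stores weight as (out, in) while a Keras Dense kernel is (in, out), and nn.Conv2d stores (out, in, kh, kw) while a Keras Conv2D kernel is (kh, kw, in, out). A NumPy sketch of those mappings (the helper names are hypothetical); the resulting lists are in the order Keras layer.set_weights expects for these layer types. One easy-to-miss detail when checking a port: PyTorch nn.LayerNorm defaults to eps=1e-5, whereas Keras LayerNormalization defaults to epsilon=1e-3, so a port should set it explicitly (whether this affects the weights here is not confirmed):

```python
import numpy as np


def torch_linear_to_keras_dense(weight, bias):
    """nn.Linear weight (out, in) -> Keras Dense [kernel (in, out), bias]."""
    return [np.transpose(weight), bias]


def torch_conv2d_to_keras(weight, bias):
    """nn.Conv2d weight (out, in, kh, kw) -> Keras Conv2D [kernel (kh, kw, in, out), bias]."""
    return [np.transpose(weight, (2, 3, 1, 0)), bias]


def torch_layernorm_to_keras(weight, bias):
    """nn.LayerNorm weight/bias map directly onto Keras gamma/beta."""
    return [weight, bias]


# Example: Swin-T patch embedding, nn.Conv2d(3, 96, kernel_size=4, stride=4).
torch_w = np.zeros((96, 3, 4, 4), dtype=np.float32)
kernel, bias = torch_conv2d_to_keras(torch_w, np.zeros(96, dtype=np.float32))
print(kernel.shape)  # (4, 4, 3, 96)
```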
  • Error when loading the pretrained model in TF

    import tensorflow as tf
    from swintransformer import SwinTransformer

    model = tf.keras.Sequential([
      tf.keras.layers.Lambda(lambda data: tf.keras.applications.imagenet_utils.preprocess_input(tf.cast(data, tf.float32), mode="torch"), input_shape=[*IMAGE_SIZE, 3]),
      SwinTransformer('swin_tiny_224', include_top=False, pretrained=True),
      tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')
    ])

    TF can't load the pretrained model; this step raises an error.

    opened by jangjiun
  • Please run in eager mode or implement the `compute_output_shape` method on your layer (SwinTransformerModel)

    Has anyone tried to use the pretrained model with a TimeDistributed layer?

    from tensorflow.keras import models
    from tensorflow.keras.layers import TimeDistributed
    # tf and SwinTransformer imported as in the README
    model = tf.keras.Sequential([
        tf.keras.layers.Lambda(lambda data: tf.keras.applications.imagenet_utils.preprocess_input(tf.cast(data, tf.float32), mode="torch"),
                               input_shape=[224, 224, 3]),
        SwinTransformer('swin_base_224', include_top=False, pretrained=True)])
    model_f = models.Sequential()
    model_f.add(TimeDistributed(model, input_shape=(8, 224, 224, 3)))
    
    

    I get the following error:

    NotImplementedError: Exception encountered when calling layer "time_distributed" (type TimeDistributed).
    
    Please run in eager mode or implement the `compute_output_shape` method on your layer (SwinTransformerModel).
    
    Call arguments received by layer "time_distributed" (type TimeDistributed):
      • inputs=tf.Tensor(shape=(None, 8, 224, 224, 3), dtype=float32)
      • training=False
    
    
    opened by atelili
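A common alternative to TimeDistributed is to fold the time axis into the batch axis, run the per-image backbone once, and unfold the result. This is similar in effect to TimeDistributed but does not need the wrapped layer's compute_output_shape. A NumPy sketch of the idea, with global average pooling standing in for the backbone; apply_per_frame and global_avg_pool are hypothetical helpers, and in Keras the two reshapes could be done with tf.reshape inside a custom layer around the SwinTransformer backbone (untested here):

```python
import numpy as np


def apply_per_frame(backbone_fn, clips):
    """Run a per-image function over every frame of (batch, time, H, W, C) clips."""
    b, t = clips.shape[:2]
    frames = clips.reshape(b * t, *clips.shape[2:])       # (batch * time, H, W, C)
    features = backbone_fn(frames)                        # (batch * time, F)
    return features.reshape(b, t, *features.shape[1:])    # (batch, time, F)


def global_avg_pool(images):
    """Stand-in backbone: average over the spatial axes H and W."""
    return images.mean(axis=(1, 2))


clips = np.random.default_rng(0).random((2, 8, 32, 32, 3), dtype=np.float32)
out = apply_per_frame(global_avg_pool, clips)
print(out.shape)  # (2, 8, 3)
```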
Releases: v0.1-tf-swin-weights
Official repository of the paper "A Variational Approximation for Analyzing the Dynamics of Panel Data". Mixed Effect Neural ODE. UAI 2021.

Official repository of the paper (UAI 2021) "A Variational Approximation for Analyzing the Dynamics of Panel Data", Mixed Effect Neural ODE. Panel dat

Jurijs Nazarovs 7 Nov 26, 2022
StocksMA is a package to facilitate access to financial and economic data of Moroccan stocks.

Creating easier access to the Moroccan stock market data What is StocksMA ? StocksMA is a package to facilitate access to financial and economic data

Salah Eddine LABIAD 28 Jan 04, 2023
Adaptable tools to make reinforcement learning and evolutionary computation algorithms.

Pearl The Parallel Evolutionary and Reinforcement Learning Library (Pearl) is a pytorch based package with the goal of being excellent for rapid proto

38 Jan 01, 2023
Unofficial Pytorch Implementation of WaveGrad2

WaveGrad 2 — Unofficial PyTorch Implementation WaveGrad 2: Iterative Refinement for Text-to-Speech Synthesis Unofficial PyTorch+Lightning Implementati

MINDs Lab 104 Nov 29, 2022
Official Implementation of DDOD (Disentangle your Dense Object Detector), ACM MM2021

Disentangle Your Dense Object Detector This repo contains the supported code and configuration files to reproduce object detection results of Disentan

loveSnowBest 51 Jan 07, 2023
Dahua Camera and Doorbell Home Assistant Integration

Home Assistant Dahua Integration The Dahua Home Assistant integration allows you to integrate your Dahua cameras and doorbells in Home Assistant. It's

Ronnie 216 Dec 26, 2022
CPF: Learning a Contact Potential Field to Model the Hand-object Interaction

Contact Potential Field This repo contains model, demo, and test codes of our paper: CPF: Learning a Contact Potential Field to Model the Hand-object

Lixin YANG 99 Dec 26, 2022
Experiments for Fake News explainability project

fake-news-explainability Experiments for fake news explainability project This repository only contains the notebooks used to train the models and eva

Lorenzo Flores (Lj) 1 Dec 03, 2022
Official Keras Implementation for UNet++ in IEEE Transactions on Medical Imaging and DLMIA 2018

UNet++: A Nested U-Net Architecture for Medical Image Segmentation UNet++ is a new general purpose image segmentation architecture for more accurate i

Zongwei Zhou 1.8k Dec 27, 2022
Learning hidden low dimensional dynamics using a Generalized Onsager Principle and neural networks

OnsagerNet Learning hidden low dimensional dynamics using a Generalized Onsager Principle and neural networks This is the original pyTorch implemenati

Haijun.Yu 3 Aug 24, 2022
A generalized framework for prototyping full-stack cooperative driving automation applications under CARLA+SUMO.

OpenCDA OpenCDA is a SIMULATION tool integrated with a prototype cooperative driving automation (CDA; see SAE J3216) pipeline as well as regular autom

UCLA Mobility Lab 726 Dec 29, 2022
An unopinionated replacement for PyTorch's Dataset and ImageFolder, that handles Tar archives

Simple Tar Dataset An unopinionated replacement for PyTorch's Dataset and ImageFolder classes, for datasets stored as uncompressed Tar archives. Just

Joao Henriques 47 Dec 20, 2022
Official pytorch implementation of paper Dual-Level Collaborative Transformer for Image Captioning (AAAI 2021).

Dual-Level Collaborative Transformer for Image Captioning This repository contains the reference code for the paper Dual-Level Collaborative Transform

lyricpoem 160 Dec 11, 2022
A PyTorch implementation of "From Two to One: A New Scene Text Recognizer with Visual Language Modeling Network" (ICCV2021)

From Two to One: A New Scene Text Recognizer with Visual Language Modeling Network The official code of VisionLAN (ICCV2021). VisionLAN successfully a

81 Dec 12, 2022
Multi-Scale Progressive Fusion Network for Single Image Deraining

Multi-Scale Progressive Fusion Network for Single Image Deraining (MSPFN) This is an implementation of the MSPFN model proposed in the paper (Multi-Sc

Kuijiang 128 Nov 21, 2022
"Structure-Augmented Text Representation Learning for Efficient Knowledge Graph Completion"(WWW 2021)

STAR_KGC This repo contains the source code of the paper accepted by WWW'2021. "Structure-Augmented Text Representation Learning for Efficient Knowled

Bo Wang 60 Dec 26, 2022
Code for "PVNet: Pixel-wise Voting Network for 6DoF Pose Estimation" CVPR 2019 oral

Good news! We release a clean version of PVNet: clean-pvnet, including how to train the PVNet on the custom dataset. Use PVNet with a detector. The tr

ZJU3DV 722 Dec 27, 2022
An unofficial PyTorch implementation of a federated learning algorithm, FedAvg.

Federated Averaging (FedAvg) in PyTorch An unofficial implementation of FederatedAveraging (or FedAvg) algorithm proposed in the paper Communication-E

Seok-Ju Hahn 123 Jan 06, 2023
Syntax-Aware Action Targeting for Video Captioning

Syntax-Aware Action Targeting for Video Captioning Code for SAAT from "Syntax-Aware Action Targeting for Video Captioning" (Accepted to CVPR 2020). Th

59 Oct 13, 2022
A Python reference implementation of the CF data model

cfdm A Python reference implementation of the CF data model. References Compliance with FAIR principles Documentation https://ncas-cms.github.io/cfdm

NCAS CMS 25 Dec 13, 2022