Unofficial implementation of "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" (https://arxiv.org/abs/2103.14030)

Overview

Swin-Transformer-Tensorflow

A direct translation of the official PyTorch implementation of "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" to TensorFlow 2.

The official Pytorch implementation can be found here.

Introduction:

Swin Transformer Architecture Diagram

Swin Transformer (the name Swin stands for Shifted window) is initially described in arxiv, which capably serves as a general-purpose backbone for computer vision. It is basically a hierarchical Transformer whose representation is computed with shifted windows. The shifted windowing scheme brings greater efficiency by limiting self-attention computation to non-overlapping local windows while also allowing for cross-window connection.

Swin Transformer achieves strong performance on COCO object detection (58.7 box AP and 51.1 mask AP on test-dev) and ADE20K semantic segmentation (53.5 mIoU on val), surpassing previous models by a large margin.

Usage:

1. To Run a Pre-trained Swin Transformer

Swin-T:

python main.py --cfg configs/swin_tiny_patch4_window7_224.yaml --include_top 1 --resume 1 --weights_type imagenet_1k

Swin-S:

python main.py --cfg configs/swin_small_patch4_window7_224.yaml --include_top 1 --resume 1 --weights_type imagenet_1k

Swin-B:

python main.py --cfg configs/swin_base_patch4_window7_224.yaml --include_top 1 --resume 1 --weights_type imagenet_1k

The possible options for cfg and weights_type are:

cfg weights_type 22K model 1K Model
configs/swin_tiny_patch4_window7_224.yaml imagenet_1k - github
configs/swin_small_patch4_window7_224.yaml imagenet_1k - github
configs/swin_base_patch4_window7_224.yaml imagenet_1k - github
configs/swin_base_patch4_window12_384.yaml imagenet_1k - github
configs/swin_base_patch4_window7_224.yaml imagenet_22kto1k - github
configs/swin_base_patch4_window12_384.yaml imagenet_22kto1k - github
configs/swin_large_patch4_window7_224.yaml imagenet_22kto1k - github
configs/swin_large_patch4_window12_384.yaml imagenet_22kto1k - github
configs/swin_base_patch4_window7_224.yaml imagenet_22k github -
configs/swin_base_patch4_window12_384.yaml imagenet_22k github -
configs/swin_large_patch4_window7_224.yaml imagenet_22k github -
configs/swin_large_patch4_window12_384.yaml imagenet_22k github -

2. Create custom models

To create a custom classification model:

import argparse

import tensorflow as tf

from config import get_config
from models.build import build_model

parser = argparse.ArgumentParser('Custom Swin Transformer')

parser.add_argument(
    '--cfg',
    type=str,
    metavar="FILE",
    help='path to config file',
    default="CUSTOM_YAML_FILE_PATH"
)
parser.add_argument(
    '--resume',
    type=int,
    help='Whether or not to resume training from pretrained weights',
    choices={0, 1},
    default=1,
)
parser.add_argument(
    '--weights_type',
    type=str,
    help='Type of pretrained weight file to load including number of classes',
    choices={"imagenet_1k", "imagenet_22k", "imagenet_22kto1k"},
    default="imagenet_1k",
)

args = parser.parse_args()
custom_config = get_config(args, include_top=False)

swin_transformer = tf.keras.Sequential([
    build_model(config=custom_config, load_pretrained=args.resume, weights_type=args.weights_type),
    tf.keras.layers.Dense(CUSTOM_NUM_CLASSES)
)

Model ouputs are logits, so don't forget to include softmax in training/inference!!

You can easily customize the model configs with custom YAML files. Predefined YAML files provided by Microsoft are located in the configs directory.

3. Convert PyTorch pretrained weights into Tensorflow checkpoints

We provide a python script with which we convert official PyTorch weights into Tensorflow checkpoints.

$ python convert_weights.py --cfg config_file --weights the_path_to_pytorch_weights --weights_type type_of_pretrained_weights --output the_path_to_output_tf_weights

TODO:

  • Translate model code over to TensorFlow
  • Load PyTorch pretrained weights into TensorFlow model
  • Write trainer code
  • Reproduce results presented in paper
    • Object Detection
  • Reproduce training efficiency of official code in TensorFlow

Citations:

@misc{liu2021swin,
      title={Swin Transformer: Hierarchical Vision Transformer using Shifted Windows}, 
      author={Ze Liu and Yutong Lin and Yue Cao and Han Hu and Yixuan Wei and Zheng Zhang and Stephen Lin and Baining Guo},
      year={2021},
      eprint={2103.14030},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
You might also like...
This is an official implementation of our CVPR 2021 paper "Bottom-Up Human Pose Estimation Via Disentangled Keypoint Regression" (https://arxiv.org/abs/2104.02300)

Bottom-Up Human Pose Estimation Via Disentangled Keypoint Regression Introduction In this paper, we are interested in the bottom-up paradigm of estima

Non-Official Pytorch implementation of
Non-Official Pytorch implementation of "Face Identity Disentanglement via Latent Space Mapping" https://arxiv.org/abs/2005.07728 Using StyleGAN2 instead of StyleGAN

Face Identity Disentanglement via Latent Space Mapping - Implement in pytorch with StyleGAN 2 Description Pytorch implementation of the paper Face Ide

Minimal implementation of PAWS (https://arxiv.org/abs/2104.13963) in TensorFlow.
Minimal implementation of PAWS (https://arxiv.org/abs/2104.13963) in TensorFlow.

PAWS-TF 🐾 Implementation of Semi-Supervised Learning of Visual Features by Non-Parametrically Predicting View Assignments with Support Samples (PAWS)

A PyTorch implementation of EventProp [https://arxiv.org/abs/2009.08378], a method to train Spiking Neural Networks
A PyTorch implementation of EventProp [https://arxiv.org/abs/2009.08378], a method to train Spiking Neural Networks

Spiking Neural Network training with EventProp This is an unofficial PyTorch implemenation of EventProp, a method to compute exact gradients for Spiki

Pytorch implementation of Distributed Proximal Policy Optimization: https://arxiv.org/abs/1707.02286
Pytorch implementation of Distributed Proximal Policy Optimization: https://arxiv.org/abs/1707.02286

Pytorch-DPPO Pytorch implementation of Distributed Proximal Policy Optimization: https://arxiv.org/abs/1707.02286 Using PPO with clip loss (from https

Tensorflow implementation of Semi-supervised Sequence Learning (https://arxiv.org/abs/1511.01432)
Tensorflow implementation of Semi-supervised Sequence Learning (https://arxiv.org/abs/1511.01432)

Transfer Learning for Text Classification with Tensorflow Tensorflow implementation of Semi-supervised Sequence Learning(https://arxiv.org/abs/1511.01

PyTorch implementation of Asymmetric Siamese (https://arxiv.org/abs/2204.00613)
PyTorch implementation of Asymmetric Siamese (https://arxiv.org/abs/2204.00613)

Asym-Siam: On the Importance of Asymmetry for Siamese Representation Learning This is a PyTorch implementation of the Asym-Siam paper, CVPR 2022: @inp

This repository contains the code used for Predicting Patient Outcomes with Graph Representation Learning (https://arxiv.org/abs/2101.03940).
This repository contains the code used for Predicting Patient Outcomes with Graph Representation Learning (https://arxiv.org/abs/2101.03940).

Predicting Patient Outcomes with Graph Representation Learning This repository contains the code used for Predicting Patient Outcomes with Graph Repre

https://arxiv.org/abs/2102.11005
https://arxiv.org/abs/2102.11005

LogME LogME: Practical Assessment of Pre-trained Models for Transfer Learning How to use Just feed the features f and labels y to the function, and yo

Comments
  • Custom Swin Transformer: error: unrecognized arguments

    Custom Swin Transformer: error: unrecognized arguments

    parser = argparse.ArgumentParser('Custom Swin Transformer')

    parser.add_argument( '--cfg', type=str, metavar="FILE", help='/content/Swin-Transformer-Tensorflow/configs/swin_tiny_patch4_window7_224.yaml', default="CUSTOM_YAML_FILE_PATH" ) parser.add_argument( '--resume', type=int, help=1, choices={0, 1}, default=1, ) parser.add_argument( '--weights_type', type=str, help='imagenet_22k', choices={"imagenet_1k", "imagenet_22k", "imagenet_22kto1k"}, default="imagenet_1k", )

    args = parser.parse_args() custom_config = get_config(args, include_top=False)

    i am trying to use it but it throws an error below

    usage: Custom Swin Transformer [-h] [--cfg FILE] [--resume {0,1}] [--weights_type {imagenet_22kto1k,imagenet_1k,imagenet_22k}] Custom Swin Transformer: error: unrecognized arguments: -f /root/.local/share/jupyter/runtime/kernel-ee309a98-1f20-4bb7-aa12-c2980aea076c.json An exception has occurred, use %tb to see the full traceback.

    SystemExit: 2

    opened by AliKayhanAtay 1
  • train dataset

    train dataset

    Thank you for Thank you for providing your code. I've been running the pretrained model, and I'd like to know how to learn about custom data from the code you provided and how to transfer learning to custom data using the pretrained model. Thank you.

    opened by hoyeoung 1
Spectralformer: Rethinking hyperspectral image classification with transformers

Spectralformer: Rethinking hyperspectral image classification with transformers Danfeng Hong, Zhu Han, Jing Yao, Lianru Gao, Bing Zhang, Antonio Plaza

Danfeng Hong 102 Dec 29, 2022
A Jinja extension (compatible with Flask and other frameworks) to compile and/or compress your assets.

A Jinja extension (compatible with Flask and other frameworks) to compile and/or compress your assets.

Jayson Reis 94 Nov 21, 2022
EfficientDet (Scalable and Efficient Object Detection) implementation in Keras and Tensorflow

EfficientDet This is an implementation of EfficientDet for object detection on Keras and Tensorflow. The project is based on the official implementati

1.3k Dec 19, 2022
ML models implementation practice

Let's implement various ML algorithms with numpy/tf Vanilla Neural Network https://towardsdatascience.com/lets-code-a-neural-network-in-plain-numpy-ae

Jinsoo Heo 4 Jul 04, 2021
Distributional Sliced-Wasserstein distance code

Distributional Sliced Wasserstein distance This is a pytorch implementation of the paper "Distributional Sliced-Wasserstein and Applications to Genera

VinAI Research 39 Jan 01, 2023
Real-time Neural Representation Fusion for Robust Volumetric Mapping

NeuralBlox: Real-Time Neural Representation Fusion for Robust Volumetric Mapping Paper | Supplementary This repository contains the implementation of

ETHZ ASL 106 Dec 24, 2022
Help you understand Manual and w/ Clutch point while driving.

简体中文 forza_auto_gear forza_auto_gear is a tool for Forza Horizon 5. It will help us understand the best gear shift point using Manual or w/ Clutch in

15 Oct 08, 2022
Official repository with code and data accompanying the NAACL 2021 paper "Hurdles to Progress in Long-form Question Answering" (https://arxiv.org/abs/2103.06332).

Hurdles to Progress in Long-form Question Answering This repository contains the official scripts and datasets accompanying our NAACL 2021 paper, "Hur

Kalpesh Krishna 41 Nov 08, 2022
Implementation of ToeplitzLDA for spatiotemporal stationary time series data.

Code for the ToeplitzLDA classifier proposed in here. The classifier conforms sklearn and can be used as a drop-in replacement for other LDA classifiers. For in-depth usage refer to the learning from

Jan Sosulski 5 Nov 07, 2022
RL-GAN: Transfer Learning for Related Reinforcement Learning Tasks via Image-to-Image Translation

RL-GAN: Transfer Learning for Related Reinforcement Learning Tasks via Image-to-Image Translation RL-GAN is an official implementation of the paper: T

42 Nov 10, 2022
Regularized Frank-Wolfe for Dense CRFs: Generalizing Mean Field and Beyond

CRF - Conditional Random Fields A library for dense conditional random fields (CRFs). This is the official accompanying code for the paper Regularized

Đ.Khuê Lê-Huu 21 Nov 26, 2022
Face and Pose detector that emits MQTT events when a face or human body is detected and not detected.

Face Detect MQTT Face or Pose detector that emits MQTT events when a face or human body is detected and not detected. I built this as an alternative t

Jacob Morris 38 Oct 21, 2022
PAthological QUpath Obsession - QuPath and Python conversations

PAQUO: PAthological QUpath Obsession Welcome to paquo 👋 , a library for interacting with QuPath from Python. paquo's goal is to provide a pythonic in

Bayer AG 60 Dec 31, 2022
An unofficial personal implementation of UM-Adapt, specifically to tackle joint estimation of panoptic segmentation and depth prediction for autonomous driving datasets.

Semisupervised Multitask Learning This repository is an unofficial and slightly modified implementation of UM-Adapt[1] using PyTorch. This code primar

Abhinav Atrishi 11 Nov 25, 2022
Code for Environment Dynamics Decomposition (ED2).

ED2 Code for Environment Dynamics Decomposition (ED2). Installation Follow the installation in MBPO and Dreamer. Usage First follow the SD2 method for

0 Aug 10, 2021
SFD implement with pytorch

S³FD: Single Shot Scale-invariant Face Detector A PyTorch Implementation of Single Shot Scale-invariant Face Detector Description Meanwhile train hand

Jun Li 251 Dec 22, 2022
ImageNet-CoG is a benchmark for concept generalization. It provides a full evaluation framework for pre-trained visual representations which measure how well they generalize to unseen concepts.

The ImageNet-CoG Benchmark Project Website Paper (arXiv) Code repository for the ImageNet-CoG Benchmark introduced in the paper "Concept Generalizatio

NAVER 23 Oct 09, 2022
Keras community contributions

keras-contrib : Keras community contributions Keras-contrib is deprecated. Use TensorFlow Addons. The future of Keras-contrib: We're migrating to tens

Keras 1.6k Dec 21, 2022
An implementation of Fastformer: Additive Attention Can Be All You Need in TensorFlow

Fast Transformer This repo implements Fastformer: Additive Attention Can Be All You Need by Wu et al. in TensorFlow. Fast Transformer is a Transformer

Rishit Dagli 139 Dec 28, 2022
Panoptic SegFormer: Delving Deeper into Panoptic Segmentation with Transformers

Panoptic SegFormer: Delving Deeper into Panoptic Segmentation with Transformers Results results on COCO val Backbone Method Lr Schd PQ Config Download

155 Dec 20, 2022