Unofficial implementation of "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" (https://arxiv.org/abs/2103.14030)

Overview

Swin-Transformer-Tensorflow

A direct translation of the official PyTorch implementation of "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" to TensorFlow 2.

The official Pytorch implementation can be found here.

Introduction:

Swin Transformer Architecture Diagram

Swin Transformer (the name Swin stands for Shifted window) is initially described in arxiv, which capably serves as a general-purpose backbone for computer vision. It is basically a hierarchical Transformer whose representation is computed with shifted windows. The shifted windowing scheme brings greater efficiency by limiting self-attention computation to non-overlapping local windows while also allowing for cross-window connection.

Swin Transformer achieves strong performance on COCO object detection (58.7 box AP and 51.1 mask AP on test-dev) and ADE20K semantic segmentation (53.5 mIoU on val), surpassing previous models by a large margin.

Usage:

1. To Run a Pre-trained Swin Transformer

Swin-T:

python main.py --cfg configs/swin_tiny_patch4_window7_224.yaml --include_top 1 --resume 1 --weights_type imagenet_1k

Swin-S:

python main.py --cfg configs/swin_small_patch4_window7_224.yaml --include_top 1 --resume 1 --weights_type imagenet_1k

Swin-B:

python main.py --cfg configs/swin_base_patch4_window7_224.yaml --include_top 1 --resume 1 --weights_type imagenet_1k

The possible options for cfg and weights_type are:

cfg weights_type 22K model 1K Model
configs/swin_tiny_patch4_window7_224.yaml imagenet_1k - github
configs/swin_small_patch4_window7_224.yaml imagenet_1k - github
configs/swin_base_patch4_window7_224.yaml imagenet_1k - github
configs/swin_base_patch4_window12_384.yaml imagenet_1k - github
configs/swin_base_patch4_window7_224.yaml imagenet_22kto1k - github
configs/swin_base_patch4_window12_384.yaml imagenet_22kto1k - github
configs/swin_large_patch4_window7_224.yaml imagenet_22kto1k - github
configs/swin_large_patch4_window12_384.yaml imagenet_22kto1k - github
configs/swin_base_patch4_window7_224.yaml imagenet_22k github -
configs/swin_base_patch4_window12_384.yaml imagenet_22k github -
configs/swin_large_patch4_window7_224.yaml imagenet_22k github -
configs/swin_large_patch4_window12_384.yaml imagenet_22k github -

2. Create custom models

To create a custom classification model:

import argparse

import tensorflow as tf

from config import get_config
from models.build import build_model

parser = argparse.ArgumentParser('Custom Swin Transformer')

parser.add_argument(
    '--cfg',
    type=str,
    metavar="FILE",
    help='path to config file',
    default="CUSTOM_YAML_FILE_PATH"
)
parser.add_argument(
    '--resume',
    type=int,
    help='Whether or not to resume training from pretrained weights',
    choices={0, 1},
    default=1,
)
parser.add_argument(
    '--weights_type',
    type=str,
    help='Type of pretrained weight file to load including number of classes',
    choices={"imagenet_1k", "imagenet_22k", "imagenet_22kto1k"},
    default="imagenet_1k",
)

args = parser.parse_args()
custom_config = get_config(args, include_top=False)

swin_transformer = tf.keras.Sequential([
    build_model(config=custom_config, load_pretrained=args.resume, weights_type=args.weights_type),
    tf.keras.layers.Dense(CUSTOM_NUM_CLASSES)
)

Model ouputs are logits, so don't forget to include softmax in training/inference!!

You can easily customize the model configs with custom YAML files. Predefined YAML files provided by Microsoft are located in the configs directory.

3. Convert PyTorch pretrained weights into Tensorflow checkpoints

We provide a python script with which we convert official PyTorch weights into Tensorflow checkpoints.

$ python convert_weights.py --cfg config_file --weights the_path_to_pytorch_weights --weights_type type_of_pretrained_weights --output the_path_to_output_tf_weights

TODO:

  • Translate model code over to TensorFlow
  • Load PyTorch pretrained weights into TensorFlow model
  • Write trainer code
  • Reproduce results presented in paper
    • Object Detection
  • Reproduce training efficiency of official code in TensorFlow

Citations:

@misc{liu2021swin,
      title={Swin Transformer: Hierarchical Vision Transformer using Shifted Windows}, 
      author={Ze Liu and Yutong Lin and Yue Cao and Han Hu and Yixuan Wei and Zheng Zhang and Stephen Lin and Baining Guo},
      year={2021},
      eprint={2103.14030},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
You might also like...
This is an official implementation of our CVPR 2021 paper "Bottom-Up Human Pose Estimation Via Disentangled Keypoint Regression" (https://arxiv.org/abs/2104.02300)

Bottom-Up Human Pose Estimation Via Disentangled Keypoint Regression Introduction In this paper, we are interested in the bottom-up paradigm of estima

Non-Official Pytorch implementation of
Non-Official Pytorch implementation of "Face Identity Disentanglement via Latent Space Mapping" https://arxiv.org/abs/2005.07728 Using StyleGAN2 instead of StyleGAN

Face Identity Disentanglement via Latent Space Mapping - Implement in pytorch with StyleGAN 2 Description Pytorch implementation of the paper Face Ide

Minimal implementation of PAWS (https://arxiv.org/abs/2104.13963) in TensorFlow.
Minimal implementation of PAWS (https://arxiv.org/abs/2104.13963) in TensorFlow.

PAWS-TF 🐾 Implementation of Semi-Supervised Learning of Visual Features by Non-Parametrically Predicting View Assignments with Support Samples (PAWS)

A PyTorch implementation of EventProp [https://arxiv.org/abs/2009.08378], a method to train Spiking Neural Networks
A PyTorch implementation of EventProp [https://arxiv.org/abs/2009.08378], a method to train Spiking Neural Networks

Spiking Neural Network training with EventProp This is an unofficial PyTorch implemenation of EventProp, a method to compute exact gradients for Spiki

Pytorch implementation of Distributed Proximal Policy Optimization: https://arxiv.org/abs/1707.02286
Pytorch implementation of Distributed Proximal Policy Optimization: https://arxiv.org/abs/1707.02286

Pytorch-DPPO Pytorch implementation of Distributed Proximal Policy Optimization: https://arxiv.org/abs/1707.02286 Using PPO with clip loss (from https

Tensorflow implementation of Semi-supervised Sequence Learning (https://arxiv.org/abs/1511.01432)
Tensorflow implementation of Semi-supervised Sequence Learning (https://arxiv.org/abs/1511.01432)

Transfer Learning for Text Classification with Tensorflow Tensorflow implementation of Semi-supervised Sequence Learning(https://arxiv.org/abs/1511.01

PyTorch implementation of Asymmetric Siamese (https://arxiv.org/abs/2204.00613)
PyTorch implementation of Asymmetric Siamese (https://arxiv.org/abs/2204.00613)

Asym-Siam: On the Importance of Asymmetry for Siamese Representation Learning This is a PyTorch implementation of the Asym-Siam paper, CVPR 2022: @inp

This repository contains the code used for Predicting Patient Outcomes with Graph Representation Learning (https://arxiv.org/abs/2101.03940).
This repository contains the code used for Predicting Patient Outcomes with Graph Representation Learning (https://arxiv.org/abs/2101.03940).

Predicting Patient Outcomes with Graph Representation Learning This repository contains the code used for Predicting Patient Outcomes with Graph Repre

https://arxiv.org/abs/2102.11005
https://arxiv.org/abs/2102.11005

LogME LogME: Practical Assessment of Pre-trained Models for Transfer Learning How to use Just feed the features f and labels y to the function, and yo

Comments
  • Custom Swin Transformer: error: unrecognized arguments

    Custom Swin Transformer: error: unrecognized arguments

    parser = argparse.ArgumentParser('Custom Swin Transformer')

    parser.add_argument( '--cfg', type=str, metavar="FILE", help='/content/Swin-Transformer-Tensorflow/configs/swin_tiny_patch4_window7_224.yaml', default="CUSTOM_YAML_FILE_PATH" ) parser.add_argument( '--resume', type=int, help=1, choices={0, 1}, default=1, ) parser.add_argument( '--weights_type', type=str, help='imagenet_22k', choices={"imagenet_1k", "imagenet_22k", "imagenet_22kto1k"}, default="imagenet_1k", )

    args = parser.parse_args() custom_config = get_config(args, include_top=False)

    i am trying to use it but it throws an error below

    usage: Custom Swin Transformer [-h] [--cfg FILE] [--resume {0,1}] [--weights_type {imagenet_22kto1k,imagenet_1k,imagenet_22k}] Custom Swin Transformer: error: unrecognized arguments: -f /root/.local/share/jupyter/runtime/kernel-ee309a98-1f20-4bb7-aa12-c2980aea076c.json An exception has occurred, use %tb to see the full traceback.

    SystemExit: 2

    opened by AliKayhanAtay 1
  • train dataset

    train dataset

    Thank you for Thank you for providing your code. I've been running the pretrained model, and I'd like to know how to learn about custom data from the code you provided and how to transfer learning to custom data using the pretrained model. Thank you.

    opened by hoyeoung 1
This is a Keras-based Python implementation of DeepMask- a complex deep neural network for learning object segmentation masks

NNProject - DeepMask This is a Keras-based Python implementation of DeepMask- a complex deep neural network for learning object segmentation masks. Th

189 Nov 16, 2022
Time-series-deep-learning - Developing Deep learning LSTM, BiLSTM models, and NeuralProphet for multi-step time-series forecasting of stock price.

Stock Price Prediction Using Deep Learning Univariate Time Series Predicting stock price using historical data of a company using Neural networks for

Abdultawwab Safarji 7 Nov 27, 2022
Liecasadi - liecasadi implements Lie groups operation written in CasADi

liecasadi liecasadi implements Lie groups operation written in CasADi, mainly di

Artificial and Mechanical Intelligence 14 Nov 05, 2022
A multi-scale unsupervised learning for deformable image registration

A multi-scale unsupervised learning for deformable image registration Shuwei Shao, Zhongcai Pei, Weihai Chen, Wentao Zhu, Xingming Wu and Baochang Zha

ShuweiShao 2 Apr 13, 2022
[EMNLP 2021] Distantly-Supervised Named Entity Recognition with Noise-Robust Learning and Language Model Augmented Self-Training

RoSTER The source code used for Distantly-Supervised Named Entity Recognition with Noise-Robust Learning and Language Model Augmented Self-Training, p

Yu Meng 60 Dec 30, 2022
Framework for training options with different attention mechanism and using them to solve downstream tasks.

Using Attention in HRL Framework for training options with different attention mechanism and using them to solve downstream tasks. Requirements GPU re

5 Nov 03, 2022
RANZCR-CLiP 7th Place Solution

RANZCR-CLiP 7th Place Solution This repository is WIP. (18 Mar 2021) Installation git clone https://github.com/analokmaus/kaggle-ranzcr-clip-public.gi

Hiroshechka Y 21 Oct 22, 2022
Official code for "End-to-End Optimization of Scene Layout" -- including VAE, Diff Render, SPADE for colorization (CVPR 2020 Oral)

End-to-End Optimization of Scene Layout Code release for: End-to-End Optimization of Scene Layout CVPR 2020 (Oral) Project site, Bibtex For help conta

Andrew Luo 41 Dec 09, 2022
Code for generating the figures in the paper "Capacity of Group-invariant Linear Readouts from Equivariant Representations: How Many Objects can be Linearly Classified Under All Possible Views?"

Code for running simulations for the paper "Capacity of Group-invariant Linear Readouts from Equivariant Representations: How Many Objects can be Lin

Matthew Farrell 1 Nov 22, 2022
A curated list of awesome Deep Learning tutorials, projects and communities.

Awesome Deep Learning Table of Contents Books Courses Videos and Lectures Papers Tutorials Researchers Websites Datasets Conferences Frameworks Tools

Christos 20k Jan 05, 2023
Details about the wide minima density hypothesis and metrics to compute width of a minima

wide-minima-density-hypothesis Details about the wide minima density hypothesis and metrics to compute width of a minima This repo presents the wide m

Nikhil Iyer 9 Dec 27, 2022
ML course - EPFL Machine Learning Course, Fall 2021

EPFL Machine Learning Course CS-433 Machine Learning Course, Fall 2021 Repository for all lecture notes, labs and projects - resources, code templates

EPFL Machine Learning and Optimization Laboratory 1k Jan 04, 2023
Deep deconfounded recommender (Deep-Deconf) for paper "Deep causal reasoning for recommendations"

Deep Causal Reasoning for Recommender Systems The codes are associated with the following paper: Deep Causal Reasoning for Recommendations, Yaochen Zh

Yaochen Zhu 22 Oct 15, 2022
Ladder Variational Autoencoders (LVAE) in PyTorch

Ladder Variational Autoencoders (LVAE) PyTorch implementation of Ladder Variational Autoencoders (LVAE) [1]: where the variational distributions q at

Andrea Dittadi 63 Dec 22, 2022
A voice recognition assistant similar to amazon alexa, siri and google assistant.

kenyan-Siri Build an Artificial Assistant Full tutorial (video) To watch the tutorial, click on the image below Installation For windows users (run th

Alison Parker 3 Aug 19, 2022
Application of K-means algorithm on a music dataset after a dimensionality reduction with PCA

PCA for dimensionality reduction combined with Kmeans Goal The Goal of this notebook is to apply a dimensionality reduction on a big dataset in order

Arturo Ghinassi 0 Sep 17, 2022
Image-to-image translation with conditional adversarial nets

pix2pix Project | Arxiv | PyTorch Torch implementation for learning a mapping from input images to output images, for example: Image-to-Image Translat

Phillip Isola 9.3k Jan 08, 2023
LexGLUE: A Benchmark Dataset for Legal Language Understanding in English

LexGLUE: A Benchmark Dataset for Legal Language Understanding in English ⚖️ 🏆 🧑‍🎓 👩‍⚖️ Dataset Summary Inspired by the recent widespread use of th

95 Dec 08, 2022
Simple PyTorch implementations of Badnets on MNIST and CIFAR10.

Simple PyTorch implementations of Badnets on MNIST and CIFAR10.

Vera 75 Dec 13, 2022
An Unpaired Sketch-to-Photo Translation Model

Unpaired-Sketch-to-Photo-Translation We have released our code at https://github.com/rt219/Unsupervised-Sketch-to-Photo-Synthesis This project is the

38 Oct 28, 2022