CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped Windows

Overview

CSWin-Transformer

This repo is the official implementation of "CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped Windows". The code and models for downstream tasks are coming soon.

Introduction

CSWin Transformer (the name CSWin stands for Cross-Shaped Window), introduced on arXiv, is a new general-purpose backbone for computer vision. It is a hierarchical Transformer that replaces traditional full attention with our newly proposed cross-shaped window self-attention. This mechanism computes self-attention in horizontal and vertical stripes in parallel; together the stripes form a cross-shaped window, and each stripe is obtained by splitting the input feature into stripes of equal width. With CSWin, global attention can be realized at a limited computation cost.
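
To make the shape of the attention region concrete, the sketch below enumerates which positions a single token can reach through the horizontal and vertical stripes that contain it. It is a plain-Python illustration, not code from this repository: the function name cross_shaped_region and the parameter stripe_width are assumptions made for the example, and the way attention heads are divided between the two stripe directions is not modelled.

```python
def cross_shaped_region(row, col, height, width, stripe_width):
    """Positions visible to the token at (row, col) on a height x width grid.

    Hypothetical helper for illustration only: it returns the union of the
    horizontal stripe and the vertical stripe that contain the token.
    """
    # Horizontal stripe: a band of `stripe_width` rows spanning every column.
    top = (row // stripe_width) * stripe_width
    horizontal = {
        (r, c)
        for r in range(top, min(top + stripe_width, height))
        for c in range(width)
    }
    # Vertical stripe: a band of `stripe_width` columns spanning every row.
    left = (col // stripe_width) * stripe_width
    vertical = {
        (r, c)
        for r in range(height)
        for c in range(left, min(left + stripe_width, width))
    }
    return horizontal | vertical


region = cross_shaped_region(row=5, col=9, height=14, width=14, stripe_width=2)
print(len(region), "of", 14 * 14, "positions are covered")
```

On a 14x14 grid with a stripe width of 2, the token reaches 52 of the 196 positions, and the region spans the full height and width of the grid, which is how a wider stripe trades computation for coverage.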

CSWin Transformer achieves strong performance on ImageNet classification (87.5% top-1 accuracy on val with only 97G FLOPs) and ADE20K semantic segmentation (55.7 mIoU on val), surpassing previous models by a large margin.

Main Results on ImageNet

model    pretrain      resolution  acc@1  #params  FLOPs  22K model  1K model
CSWin-T  ImageNet-1K   224x224     82.8   23M      4.3G   -          model
CSWin-S  ImageNet-1K   224x224     83.6   35M      6.9G   -          model
CSWin-B  ImageNet-1K   224x224     84.2   78M      15.0G  -          model
CSWin-B  ImageNet-1K   384x384     85.5   78M      47.0G  -          model
CSWin-L  ImageNet-22K  224x224     86.5   173M     31.5G  model      model
CSWin-L  ImageNet-22K  384x384     87.5   173M     96.8G  -          model

Main Results on Downstream Tasks

COCO Object Detection

backbone  method              pretrain     lr schd  box mAP  mask mAP  #params  FLOPs
CSWin-T   Mask R-CNN          ImageNet-1K  3x       49.0     43.6      42M      279G
CSWin-S   Mask R-CNN          ImageNet-1K  3x       50.0     44.5      54M      342G
CSWin-B   Mask R-CNN          ImageNet-1K  3x       50.8     44.9      97M      526G
CSWin-T   Cascade Mask R-CNN  ImageNet-1K  3x       52.5     45.3      80M      757G
CSWin-S   Cascade Mask R-CNN  ImageNet-1K  3x       53.7     46.4      92M      820G
CSWin-B   Cascade Mask R-CNN  ImageNet-1K  3x       53.9     46.4      135M     1004G

ADE20K Semantic Segmentation (val)

backbone  method        pretrain      crop size  lr schd  mIoU  mIoU (ms+flip)  #params  FLOPs
CSWin-T   Semantic FPN  ImageNet-1K   512x512    80K      48.2  -               26M      202G
CSWin-S   Semantic FPN  ImageNet-1K   512x512    80K      49.2  -               39M      271G
CSWin-B   Semantic FPN  ImageNet-1K   512x512    80K      49.9  -               81M      464G
CSWin-T   UPerNet       ImageNet-1K   512x512    160K     49.3  50.4            60M      959G
CSWin-S   UPerNet       ImageNet-1K   512x512    160K     50.0  50.8            65M      1027G
CSWin-B   UPerNet       ImageNet-1K   512x512    160K     50.8  51.7            109M     1222G
CSWin-B   UPerNet       ImageNet-22K  640x640    160K     51.8  52.6            109M     1941G
CSWin-L   UPerNet       ImageNet-22K  640x640    160K     53.4  55.7            208M     2745G

Requirements

Install timm==0.3.4, pytorch>=1.4, opencv, ... by running:

bash install_req.sh

Apex for mixed precision training is used for finetuning. To install apex, run:

git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./

Data preparation: arrange ImageNet in the following folder structure. You can extract ImageNet with this script.

│imagenet/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......
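
Before launching a long run it can help to confirm that the dataset really has this layout. The following is a minimal sketch using only the Python standard library; summarize_imagenet is a hypothetical helper and not part of this repository, and it assumes the images use the .JPEG suffix shown in the tree above.

```python
from pathlib import Path


def summarize_imagenet(root):
    """Count class folders and .JPEG files under root/train and root/val."""
    summary = {}
    for split in ("train", "val"):
        split_dir = Path(root) / split
        if not split_dir.is_dir():
            summary[split] = None  # the split folder is missing
            continue
        class_dirs = [d for d in split_dir.iterdir() if d.is_dir()]
        num_images = sum(
            1 for d in class_dirs for f in d.iterdir() if f.suffix == ".JPEG"
        )
        summary[split] = (len(class_dirs), num_images)
    return summary


if __name__ == "__main__":
    # "imagenet" is a placeholder; pass the same path you give to --data.
    print(summarize_imagenet("imagenet"))
```

Each split maps to a (class count, image count) pair, or to None when the folder is absent.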

Train

Train the three lite variants: CSWin-Tiny, CSWin-Small and CSWin-Base:

bash train.sh 8 --data <data path> --model CSWin_64_12211_tiny_224 -b 256 --lr 2e-3 --weight-decay .05 --amp --img-size 224 --warmup-epochs 20 --model-ema-decay 0.99984 --drop-path 0.2
bash train.sh 8 --data <data path> --model CSWin_64_24322_small_224 -b 256 --lr 2e-3 --weight-decay .05 --amp --img-size 224 --warmup-epochs 20 --model-ema-decay 0.99984 --drop-path 0.4
bash train.sh 8 --data <data path> --model CSWin_96_24322_base_224 -b 128 --lr 1e-3 --weight-decay .1 --amp --img-size 224 --warmup-epochs 20 --model-ema-decay 0.99992 --drop-path 0.5

If you want to train our CSWin on images with 384x384 resolution, please use '--img-size 384'.

If GPU memory is insufficient, use '-b 128 --lr 1e-3 --model-ema-decay 0.99992' or enable checkpointing with '--use-chk'.
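
The reduced-memory settings are consistent with scaling both the learning rate and the EMA update weight (1 - decay) in proportion to the batch size: halving '-b' from 256 to 128 halves '--lr' from 2e-3 to 1e-3 and moves '--model-ema-decay' from 0.99984 to 0.99992. The sketch below reproduces that arithmetic. The proportional rule is inferred from these two settings and is an assumption, not something the repository documents, and scale_hparams is a hypothetical helper.

```python
def scale_hparams(base_batch, base_lr, base_ema_decay, new_batch):
    """Scale lr and the EMA update weight (1 - decay) linearly with batch size."""
    ratio = new_batch / base_batch
    new_lr = base_lr * ratio
    # A smaller batch means more updates per epoch, so each EMA update
    # contributes proportionally less.
    new_ema_decay = 1.0 - (1.0 - base_ema_decay) * ratio
    return new_lr, new_ema_decay


lr, ema_decay = scale_hparams(256, 2e-3, 0.99984, 128)
print(f"--lr {lr:g} --model-ema-decay {ema_decay:.5f}")
```

For other batch sizes the same function gives a starting point only; the resulting values have not been validated against this codebase.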

Finetune

Finetune CSWin-Base with 384x384 resolution:

bash finetune.sh 8 --data <data path> --model CSWin_96_24322_base_384 -b 32 --lr 5e-6 --min-lr 5e-7 --weight-decay 1e-8 --amp --img-size 384 --warmup-epochs 0 --model-ema-decay 0.9998 --finetune <pretrained 224 model> --epochs 20 --mixup 0.1 --cooldown-epochs 10 --drop-path 0.7 --ema-finetune --lr-scale 1 --cutmix 0.1

Finetune ImageNet-22K pretrained CSWin-Large with 224x224 resolution:

bash finetune.sh 8 --data <data path> --model CSWin_144_24322_large_224 -b 64 --lr 2.5e-4 --min-lr 5e-7 --weight-decay 1e-8 --amp --img-size 224 --warmup-epochs 0 --model-ema-decay 0.9996 --finetune <22k-pretrained model> --epochs 30 --mixup 0.01 --cooldown-epochs 10 --interpolation bicubic  --lr-scale 0.05 --drop-path 0.2 --cutmix 0.3 --use-chk --fine-22k --ema-finetune

If GPU memory is insufficient, enable checkpointing with '--use-chk'.

Cite CSWin Transformer

@misc{dong2021cswin,
      title={CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped Windows},
      author={Xiaoyi Dong and Jianmin Bao and Dongdong Chen and Weiming Zhang and Nenghai Yu and Lu Yuan and Dong Chen and Baining Guo},
      year={2021},
      eprint={2107.00652},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Acknowledgement

This repository is built using the timm library and the DeiT repository.

License

This project is licensed under the license found in the LICENSE file in the root directory of this source tree.

Microsoft Open Source Code of Conduct

Contact Information

For help or issues using CSWin Transformer, please submit a GitHub issue.

For other communications related to CSWin Transformer, please contact Jianmin Bao ([email protected]), Dong Chen ([email protected]).

Comments
  • About the patches_resolution of the segmentation model

    Hello, this work is interesting, but I have some questions about the 'patches_resolution' of the segmentation model. I notice that the long side of the cross-shaped windows is the 'patches_resolution' rather than the real feature resolution. For example, in stage 3, the long side is 224 / 16 = 14. Do I understand it correctly? Does that make it impossible to exchange information outside the 'patches_resolution'?

    opened by danczs 2
  • The results of downstream task by my realization are poor

    There are some questions:

    1. Is the split size still [1 2 7 7]?
    2. Is the last-stage branch_num 2 or 1? For downstream tasks, the feature resolution in the last stage does not equal 7 (the split size). If branch_num is not 1, the pretrained weight sizes do not match.
    3. Is the padding right in my implementation?

       pad_l = pad_t = 0
       pad_r = (W_sp - W % W_sp) % W_sp
       pad_b = (H_sp - H % H_sp) % H_sp
       q = q.transpose(-2,-1).contiguous().view(B, H, W, C)
       k = q.transpose(-2,-1).contiguous().view(B, H, W, C)
       v = q.transpose(-2,-1).contiguous().view(B, H, W, C)
       if pad_r > 0 or pad_b > 0:
           q = F.pad(q, (0, 0, pad_l, pad_r, pad_t, pad_b))
           k = F.pad(k, (0, 0, pad_l, pad_r, pad_t, pad_b))
           v = F.pad(v, (0, 0, pad_l, pad_r, pad_t, pad_b))
       _, Hp, Wp, _ = q.shape

    opened by Sunting78 2
  • Experiment setting for semantic segmentation

    Hi, thank you for the code. I implemented CSWin-T with FPN for semantic segmentation on ADE20K but couldn't reach the 48.2 mIoU reported in the table. The maximum I could get was 39.9 mIoU. It would be great if you could share the exact experiment settings you used. Thanks!

    opened by AnukritiSinghh 2
  • CSWin significantly slower than Swin?

    Greetings,

    From my benchmarks I have noticed that CSWin seems to be significantly slower than Swin at inference. Is this the expected behavior? While I can get predictions in as little as 20 milliseconds on Swin Large 384, it takes more than 900 milliseconds on CSWin_144_24322_large_384.

    I performed the tests using FP16, torchscript, optimize_for_inference and torch.inference_mode.

    opened by ErenBalatkan 2
  • Pretrained settings for object detection

    Hi, I'm impressed by your excellent work.

    I have a question.

    I wonder which type of the pre-trained weights (224x224 or 384x384 finetuned) is used for object detection.

    I know both 224x224 and 384x384 are pre-trained on ImageNet-1k.

    opened by youngwanLEE 2
  • Error about building 384 models

    The code at https://github.com/microsoft/CSWin-Transformer/blob/d8be74a7833898f7bd9c77eb8c051d1b8bd5d753/models/cswin.py#L407 should be:

    model = CSWinTransformer(img_size=384, patch_size=4, embed_dim=96, depth=[2,4,32,2],
    

    The same applies to: https://github.com/microsoft/CSWin-Transformer/blob/d8be74a7833898f7bd9c77eb8c051d1b8bd5d753/models/cswin.py#L414

    opened by TingquanGao 2
  • The problem is shown in the figure

    model = CSWinTransformer(patch_size=4, embed_dim=96, depth=[2,4,32,2],
                             split_size=[1,2,12,12], num_heads=[4,8,16,32],
                             mlp_ratio=4.).cuda().eval()
    inp = torch.rand(1, 3, 224, 224).cuda()
    outs = model(inp)
    for out in outs:
        print(out.shape)

    RuntimeError: shape '[1, 192, 1, 14, 1, 12]' is invalid for input of size 37632

    Why?
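
    One possible reading of the error, offered as a sketch rather than a confirmed diagnosis: with a 224x224 input the stage-3 feature map is 224 / 16 = 14 positions on a side, and 14 is not divisible by a split size of 12, so a tensor of 192 x 14 x 14 = 37632 elements cannot be viewed as [1, 192, 1, 14, 1, 12]. The helper below checks divisibility per stage. The name check_split_sizes and the strides (4, 8, 16, 32) are assumptions for illustration, and the check applies the same rule to every stage, which may be stricter than how the model treats its last stage.

```python
def check_split_sizes(img_size, split_sizes, strides=(4, 8, 16, 32)):
    """Return (resolution, split_size, divisible) for each stage.

    The strides are assumed values for a four-stage hierarchy.
    """
    report = []
    for stride, split in zip(strides, split_sizes):
        resolution = img_size // stride
        report.append((resolution, split, resolution % split == 0))
    return report


for resolution, split, ok in check_split_sizes(224, [1, 2, 12, 12]):
    print(f"resolution={resolution:3d} split_size={split:2d} divisible={ok}")
```

    Under the same assumptions, split_size=[1, 2, 7, 7] passes the check at every stage for a 224x224 input.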

    opened by rui-cf 2
  • about the test result on imagenet-2012

    Hi! I've tested the released CSWin-Tiny-224 pretrained weights. These are my data transforms during testing:

    DEFAULT_CROP_SIZE = 0.9
    scale_size = int(math.floor(image_size / DEFAULT_CROP_SIZE))
    transform = transforms.Compose(
            [
                transforms.Resize(scale_size, interpolation=3)  # 3: bicubic
                if image_size == 224
                else transforms.Resize(image_size, interpolation=3),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
            ]
        )
    

    I can only get 80.5% on the ImageNet-2012 dataset, which is inconsistent with the results reported in this repo. Did I miss some details about the data augmentation during testing?

    opened by rentainhe 1
  • CSwin Code for Segmentation with MMSegmentation

    Hi, Thank you for your work! I wonder if you plan to release the mmsegmentation code you used for the downstream segmentation task, just like in Swin-Transformer repo.

    Best,

    opened by WalBouss 1
  • How do you produce Table 9 (ablation on different attention mechanisms) in the paper?

    Hi, thanks for your nice work. I'm comparing different attention mechanisms and want to follow your experimental settings. I have two questions:

    1. Why is the reported mIoU 41.9 for Swin-T in Table 9, while it is 46.1 in the Swin paper?
    2. Can you provide detailed experimental settings for semantic segmentation and object detection in Table 9?
    opened by rayleizhu 0
  • about the setting of --use_chk

    Passing --use_chk enables torch.utils.checkpoint to save GPU memory. I wonder whether this could hurt the final performance. Thanks a lot!

    opened by go-ahead-maker 0
  • Input image normalization parameters for semantic segmentation.

    Hey, guys, cool work!

    Unfortunately, it is not obvious to me which normalization parameters (mean, std) you used when training the semantic segmentation model on ADE20K. Are they still IMAGENET_DEFAULT_MEAN, or did you use other values?

    opened by NikAleksFed 0
  • Using transfer to train over food101

    Hi all! I'm trying to train a model for food101 using the CSWin_64_12211_tiny_224 model with its pretrained weights. The thing is, during execution it looks like it's training from scratch rather than reusing the pretrained weights. By this I mean the initial top-5 accuracy is around 5%, but I expected it to be higher than this.

    For this I loaded the pretrained model, changed its classification layer in a separate script, and saved it for use as follows:

    model = create_model(
        'CSWin_64_12211_tiny_224',
        pretrained=True,
        num_classes=1000,
        drop_rate=0.0,
        drop_connect_rate=None,  # DEPRECATED, use drop_path
        drop_path_rate=0.2,
        drop_block_rate=None,
        global_pool=None,
        bn_tf=False,
        bn_momentum=None,
        bn_eps=None,
        checkpoint_path='',
        img_size=224,
        use_chk=True)
    chk_path = './pretrained/cswin_tiny_224.pth'
    load_checkpoint(model, chk_path)
    model.reset_classifier(101, 'max')

    These are some of the runs I tried:

    bash finetune.sh 1 --data ../food-101 --model CSWin_64_12211_tiny_224 -b 32 --lr 5e-6 --min-lr 5e-7 --weight-decay 1e-8 --amp --img-size 224 --warmup-epochs 0 --model-ema-decay 0.9998 --epochs 20 --mixup 0.1 --cooldown-epochs 10 --drop-path 0.7 --ema-finetune --lr-scale 1 --cutmix 0.1 --use-chk --num-classes 101 --pretrained --finetune ./pretrained/CSWin_64_12211_tiny_224101.pth

    bash finetune.sh 1 --data ../food-101 --model CSWin_64_12211_tiny_224 -b 32 --lr 2e-3 --weight-decay .05 --amp --img-size 224 --warmup-epochs 0 --model-ema-decay 0.9998 --epochs 20 --cooldown-epochs 10 --drop-path 0.2 --ema-finetune --cutmix 0.1 --use-chk --num-classes 101 --initial-checkpoint ./pretrained/CSWin_64_12211_tiny_224101.pth --lr-scale 1.0 --output ./full_base

    Is there something I'm missing or a proper way I should try this?

    Thanks in advance for any help! :)

    opened by andreynz691 0