Official implementation of CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification

Overview

CrossViT

This repository is the official implementation of CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification. ArXiv

If you use the codes and models from this repo, please cite our work. Thanks!

@inproceedings{
    chen2021crossvit,
    title={{CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification}},
    author={Chun-Fu (Richard) Chen and Quanfu Fan and Rameswar Panda},
    booktitle={International Conference on Computer Vision (ICCV)},
    year={2021}
}

Installation

To install requirements:

pip install -r requirements.txt

With conda:

conda create -n crossvit python=3.8
conda activate crossvit
conda install pytorch=1.7.1 torchvision  cudatoolkit=11.0 -c pytorch -c nvidia
pip install -r requirements.txt

Data preparation

Download and extract ImageNet train and val images from http://image-net.org/. The directory structure is the standard layout for the torchvision datasets.ImageFolder, and the training and validation data is expected to be in the train/ folder and val folder respectively:

/path/to/imagenet/
  train/
    class1/
      img1.jpeg
    class2/
      img2.jpeg
  val/
    class1/
      img3.jpeg
    class/2
      img4.jpeg

Pretrained models

We provide models trained on ImageNet1K. You can find models here. And you can load pretrained weights into models by add --pretrained flag.

Training

To train crossvit_9_dagger_224 on ImageNet on a single node with 8 gpus for 300 epochs run:

python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model crossvit_9_dagger_224 --batch-size 256 --data-path /path/to/imagenet

Other model names can be found at models/crossvit.py.

Multinode training

Distributed training is available via Slurm and submitit:

To train a crossvit_9_dagger_224 model on ImageNet on 4 nodes with 8 gpus each for 300 epochs:

python run_with_submitit.py --nodes 4 --model crossvit_9_dagger_224 --data-path /path/to/imagenet --batch-size 128 --warmup-epochs 30

Or you can start process on each machine maunally. E.g. 2 nodes, each with 8 gpus.

Machine 0:

python -m torch.distributed.launch --nproc_per_node=8 --master_addr=MACHINE_0_IP --master_port=AVAILABLE_PORT --nnodes=2 --node_rank=0 main.py --model crossvit_9_dagger_224 --batch-size 256 --data-path /path/to/imagenet

Machine 1:

python -m torch.distributed.launch --nproc_per_node=8 --master_addr=MACHINE_0_IP --master_port=AVAILABLE_PORT --nnodes=2 --node_rank=1 main.py --model crossvit_9_dagger_224 --batch-size 256 --data-path /path/to/imagenet

Note that: some slurm configurations might need to be changed based on your cluster.

Evaluation

To evaluate a pretrained model on crossvit_9_dagger_224:

python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model crossvit_9_dagger_224 --batch-size 128 --data-path /path/to/imagenet --eval --pretrained
You might also like...
This repository builds a basic vision transformer from scratch so that one beginner can understand the theory of vision transformer.

vision-transformer-from-scratch This repository includes several kinds of vision transformers from scratch so that one beginner can understand the the

Code of U2Fusion: a unified unsupervised image fusion network for multiple image fusion tasks, including multi-modal, multi-exposure and multi-focus image fusion.

U2Fusion Code of U2Fusion: a unified unsupervised image fusion network for multiple image fusion tasks, including multi-modal (VIS-IR, medical), multi

This is an official implementation for "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" on Object Detection and Instance Segmentation.

Swin Transformer for Object Detection This repo contains the supported code and configuration files to reproduce object detection results of Swin Tran

Unofficial implementation of MUSIQ (Multi-Scale Image Quality Transformer)
Unofficial implementation of MUSIQ (Multi-Scale Image Quality Transformer)

MUSIQ: Multi-Scale Image Quality Transformer Unofficial pytorch implementation of the paper "MUSIQ: Multi-Scale Image Quality Transformer" (paper link

Official implementation of
Official implementation of "Towards Good Practices for Efficiently Annotating Large-Scale Image Classification Datasets" (CVPR2021)

Towards Good Practices for Efficiently Annotating Large-Scale Image Classification Datasets This is the official implementation of "Towards Good Pract

Official implementation of paper
Official implementation of paper "Query2Label: A Simple Transformer Way to Multi-Label Classification".

Introdunction This is the official implementation of the paper "Query2Label: A Simple Transformer Way to Multi-Label Classification". Abstract This pa

Implementation of STAM (Space Time Attention Model), a pure and simple attention model that reaches SOTA for video classification
Implementation of STAM (Space Time Attention Model), a pure and simple attention model that reaches SOTA for video classification

STAM - Pytorch Implementation of STAM (Space Time Attention Model), yet another pure and simple SOTA attention model that bests all previous models in

Pytorch implementation of ICASSP 2022 paper Attention Probe: Vision Transformer Distillation in the Wild
Pytorch implementation of ICASSP 2022 paper Attention Probe: Vision Transformer Distillation in the Wild

Attention Probe: Vision Transformer Distillation in the Wild Jiahao Wang, Mingdeng Cao, Shuwei Shi, Baoyuan Wu, Yujiu Yang In ICASSP 2022 This code is

Implementation of a memory efficient multi-head attention as proposed in the paper, "Self-attention Does Not Need O(n²) Memory"

Memory Efficient Attention Pytorch Implementation of a memory efficient multi-head attention as proposed in the paper, Self-attention Does Not Need O(

Comments
  • Multilabel classification

    Multilabel classification

    @jjasghar @krook @chunfuchen thanks a lot for sharing the code , i have a problem statement which i want to train crossvit on please find the details below

    1. input is of image shape 256x128 having label vector as gt for multilabel classification
    2. can i use crossvit to train the model , what all modification has to be done with the code base ??

    THnaks for the support

    opened by abhigoku10 4
  • Honoring distributed flag + fixing reset_classifier

    Honoring distributed flag + fixing reset_classifier

    1. Honoring the args.distributed flag in calls to evaluate().
    2. A couple of changes to make the reset_classifier() method work:
    • Initializing the embed_dim instance variable in VisionTransformer.
    • Reinitializing the classification head for all branches.
    opened by abhrac 1
  • Parameter setting

    Parameter setting

    Hello, thank you for your excellent work, I would like to know how you set the parameters on the CIFAR10 dataset, mainly the size of the patch,Looking forward to your reply

    opened by happy20200 1
Owner
International Business Machines
International Business Machines
Official PyTorch implementation for paper Context Matters: Graph-based Self-supervised Representation Learning for Medical Images

Context Matters: Graph-based Self-supervised Representation Learning for Medical Images Official PyTorch implementation for paper Context Matters: Gra

49 Nov 23, 2022
“英特尔创新大师杯”深度学习挑战赛 赛道3:CCKS2021中文NLP地址相关性任务

ccks2021-track3 CCKS2021中文NLP地址相关性任务-赛道三-冠军方案 团队:我的加菲鱼- wodejiafeiyu 初赛第二/复赛第一/决赛第一 前言 19年开始,陆陆续续参加了一些比赛,拿到过一些top,比较懒一直都没分享过,这次比较幸运又拿了top1,打算分享下 分类的任务

shaochenjie 131 Dec 31, 2022
PyTorch implementation of paper “Unbiased Scene Graph Generation from Biased Training”

A new codebase for popular Scene Graph Generation methods (2020). Visualization & Scene Graph Extraction on custom images/datasets are provided. It's also a PyTorch implementation of paper “Unbiased

Kaihua Tang 824 Jan 03, 2023
Official implementation for the paper: Permutation Invariant Graph Generation via Score-Based Generative Modeling

Permutation Invariant Graph Generation via Score-Based Generative Modeling This repo contains the official implementation for the paper Permutation In

64 Dec 29, 2022
A simple consistency training framework for semi-supervised image semantic segmentation

PseudoSeg: Designing Pseudo Labels for Semantic Segmentation PseudoSeg is a simple consistency training framework for semi-supervised image semantic s

Google Interns 143 Dec 13, 2022
ktrain is a Python library that makes deep learning and AI more accessible and easier to apply

Overview | Tutorials | Examples | Installation | FAQ | How to Cite Welcome to ktrain News and Announcements 2020-11-08: ktrain v0.25.x is released and

Arun S. Maiya 1.1k Jan 02, 2023
Official PyTorch Implementation of paper EAN: Event Adaptive Network for Efficient Action Recognition

Official PyTorch Implementation of paper EAN: Event Adaptive Network for Efficient Action Recognition

TianYuan 27 Nov 07, 2022
This repo contains source code and materials for the TEmporally COherent GAN SIGGRAPH project.

TecoGAN This repository contains source code and materials for the TecoGAN project, i.e. code for a TEmporally COherent GAN for video super-resolution

Nils Thuerey 5.2k Jan 02, 2023
Neural Ensemble Search for Performant and Calibrated Predictions

Neural Ensemble Search Introduction This repo contains the code accompanying the paper: Neural Ensemble Search for Performant and Calibrated Predictio

AutoML-Freiburg-Hannover 26 Dec 12, 2022
Pomodoro timer that acknowledges the inexorable, infinite passage of time

Pomodouroboros Most pomodoro trackers assume you're going to start them. But time and tide wait for no one - the great pomodoro of the cosmos is cold

Glyph 66 Dec 13, 2022
Official PyTorch implementation of "Uncertainty-Based Offline Reinforcement Learning with Diversified Q-Ensemble" (NeurIPS'21)

Uncertainty-Based Offline Reinforcement Learning with Diversified Q-Ensemble This is the code for reproducing the results of the paper Uncertainty-Bas

43 Nov 23, 2022
Layered Neural Atlases for Consistent Video Editing

Layered Neural Atlases for Consistent Video Editing Project Page | Paper This repository contains an implementation for the SIGGRAPH Asia 2021 paper L

Yoni Kasten 353 Dec 27, 2022
Bio-Computing Platform Featuring Large-Scale Representation Learning and Multi-Task Deep Learning “螺旋桨”生物计算工具集

English | 简体中文 Latest News 2021.10.25 Paper "Docking-based Virtual Screening with Multi-Task Learning" is accepted by BIBM 2021. 2021.07.29 PaddleHeli

633 Jan 04, 2023
A benchmark for the task of translation suggestion

WeTS: A Benchmark for Translation Suggestion Translation Suggestion (TS), which provides alternatives for specific words or phrases given the entire d

zhyang 55 Dec 24, 2022
TorchMultimodal is a PyTorch library for training state-of-the-art multimodal multi-task models at scale.

TorchMultimodal (Alpha Release) Introduction TorchMultimodal is a PyTorch library for training state-of-the-art multimodal multi-task models at scale.

Meta Research 663 Jan 06, 2023
Telegram chatbot created with deep learning model (LSTM) and telebot library.

Telegram chatbot Telegram chatbot created with deep learning model (LSTM) and telebot library. Description This program will allow you to create very

1 Jan 04, 2022
A Python Package for Portfolio Optimization using the Critical Line Algorithm

PyCLA A Python Package for Portfolio Optimization using the Critical Line Algorithm Getting started To use PyCLA, clone the repo and install the requi

19 Oct 11, 2022
An Open Source Machine Learning Framework for Everyone

Documentation TensorFlow is an end-to-end open source platform for machine learning. It has a comprehensive, flexible ecosystem of tools, libraries, a

170.1k Jan 04, 2023
Edge-aware Guidance Fusion Network for RGB-Thermal Scene Parsing

EGFNet Edge-aware Guidance Fusion Network for RGB-Thermal Scene Parsing Dataset and Results Test maps: 百度网盘 提取码:zust Citation @ARTICLE{ author={Zhou,

ShaohuaDong 10 Dec 08, 2022
BASH - Biomechanical Animated Skinned Human

We developed a method animating a statistical 3D human model for biomechanical analysis to increase accessibility for non-experts, like patients, athletes, or designers.

Machine Learning and Data Analytics Lab FAU 66 Nov 19, 2022