Pytorch implementation of "Training a 85.4% Top-1 Accuracy Vision Transformer with 56M Parameters on ImageNet"

Overview

Token Labeling: Training an 85.4% Top-1 Accuracy Vision Transformer with 56M Parameters on ImageNet (arxiv)

This is a Pytorch implementation of our technical report.

Compare

Comparison between the proposed LV-ViT and other recent works based on transformers. Note that we only show models whose model sizes are under 100M.

Training Pipeline

Pipeline

Our codes are based on the pytorch-image-models by Ross Wightman.

LV-ViT Models

Model layer dim Image resolution Param Top 1 Download
LV-ViT-S 16 384 224 26.15M 83.3 link
LV-ViT-S 16 384 384 26.30M 84.4 link
LV-ViT-M 20 512 224 55.83M 84.0 link
LV-ViT-M 20 512 384 56.03M 85.4 link
LV-ViT-L 24 768 448 150.47M 86.2 link

Requirements

torch>=1.4.0 torchvision>=0.5.0 pyyaml timm==0.4.5

data prepare: ImageNet with the following folder structure, you can extract imagenet by this script.

│imagenet/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......

Validation

Replace DATA_DIR with your imagenet validation set path and MODEL_DIR with the checkpoint path

CUDA_VISIBLE_DEVICES=0 bash eval.sh /path/to/imagenet/val /path/to/checkpoint

Label data

We provide NFNet-F6 generated dense label map here. As NFNet-F6 are based on pure ImageNet data, no extra training data is involved.

Training

Coming soon

Reference

If you use this repo or find it useful, please consider citing:

@misc{jiang2021token,
      title={Token Labeling: Training an 85.4% Top-1 Accuracy Vision Transformer with 56M Parameters on ImageNet}, 
      author={Zihang Jiang and Qibin Hou and Li Yuan and Daquan Zhou and Xiaojie Jin and Anran Wang and Jiashi Feng},
      year={2021},
      eprint={2104.10858},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Related projects

T2T-ViT, Re-labeling ImageNet.

Comments
  • error: download the pretrained model but couldn't be unzipped

    error: download the pretrained model but couldn't be unzipped

    tar -xvf lvvit_s-26M-384-84-4.pth.tar tar: This does not look like a tar archive tar: Skipping to next header tar: Exiting with failure status due to previous errors

    opened by Williamlizl 10
  • The accuracy of the validation set is 0,and the loss is always around 13

    The accuracy of the validation set is 0,and the loss is always around 13

    Hello! I use ILSVRC2012_img_train and ILSVRC2012_img_val, and use the provided label_top5_train_nfnet from Google Drive. I train lv-vit-s with batch_size 64 without apex for one epoch. Thanks for your advice.

    opened by yifanQi98 7
  • Pretrained weights for LV-ViT-T

    Pretrained weights for LV-ViT-T

    Hi,

    Thanks for sharing your work. Could you also provide the pre-trained weights for the LV-ViT-T model variant, the one that achieves 79.1% top1-acc. as mentioned in Table 1 of your paper?

    All the best, Marc

    opened by marc345 5
  • train error: AttributeError: 'tuple' object has no attribute 'log_softmax'

    train error: AttributeError: 'tuple' object has no attribute 'log_softmax'

    Hi, thanks for you great work. When I train script, some error occurs: AttributeError: 'tuple' object has no attribute 'log_softmax'

    with amp_autocast():   
                output = model(input)  
                loss = loss_fn(output, target)  # error occurs
    
    

    and loss function is train_loss_fn = LabelSmoothingCrossEntropy(smoothing=0.0).cuda()

    by the way: Could you please tell me why we need to specify smoothing=0.0?

    opened by lxy5513 5
  • RuntimeError: CUDA error: device-side assert triggered

    RuntimeError: CUDA error: device-side assert triggered

    I am a green hand of DL. When I run the code of volo with tlt in a single or multi GPU, I get an error as follows: /pytorch/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:312: operator(): block: [0,0,0], thread: [25,0,0] Assertion idx_dim >= 0 && idx_dim < index_size && "index out of bounds" failed. Traceback (most recent call last): File "main.py", line 949, in main() File "main.py", line 664, in main optimizers=optimizers) File "main.py", line 773, in train_one_epoch label_size=args.token_label_size) File "/opt/conda/lib/python3.6/site-packages/tlt/data/mixup.py", line 90, in mixup_target y1 = get_labelmaps_with_coords(target, num_classes, on_value=on_value, off_value=off_value, device=device, label_size=label_size) File "/opt/conda/lib/python3.6/site-packages/tlt/data/mixup.py", line 64, in get_labelmaps_with_coords num_classes=num_classes,device=device) File "/opt/conda/lib/python3.6/site-packages/tlt/data/mixup.py", line 16, in get_featuremaps _label_topk[1][:, :, :].long(), RuntimeError: CUDA error: device-side assert triggered.

    I can't fix this problem right now.

    opened by JIAOJIAYUASD 4
  • Generating label for custom dataset

    Generating label for custom dataset

    Hello,

    Thank you for sharing your work. I am currently trying to generate token label to a custom dataset for model lvvit_s, but I keep getting the loss close to 7 and the Accuracy 0 (not pre-trained and using 1 GPU in Google Colab). I also tried using the pre-trained model with --transfer but got 0 in both Loss and Acc . What option should I use for a custom dataset? image

    opened by AleMaiaF 2
  • generate_label.py unable to find model lvvit_s

    generate_label.py unable to find model lvvit_s

    Hi,

    When I tried to run the label generation script for the model lvvit_s it returned an error "RuntimeError: Unknown model".

    Solution: It worked when I added the line "import tlt.models" in the file generate_label.py.

    opened by AleMaiaF 2
  • Can Token labeling reach higher than annotator model?

    Can Token labeling reach higher than annotator model?

    Greetings,

    Thank you for this incredible research.

    I would like to know if it is possible to use Token Labeling to achieve scores higher than that of the annotator model, I believe this was the case with VOLO D5 model where it achieved higher score than NFNet, model used for annotation.

    opened by ErenBalatkan 1
  • label_map does not do the same augmentation (random crop) as the input image

    label_map does not do the same augmentation (random crop) as the input image

    Hi Thanks so much for the nice work! I am curious if you could share the insight on processing of the label_map. If I understand it correctly, after we load image and the corresponding, we shall do the same cropping/ flip/ resize, but in https://github.com/zihangJiang/TokenLabeling/blob/aa438eff9b9fc2daa8c8b4cc6bfaa6e3721f995e/tlt/data/label_transforms_factory.py#L58-L73 Seems only image was cropped, but the label map does not do the same cropping, which make the label map not match with the image?

    Shall we do

            return torchvision_F.resized_crop(
                    img, i, j, h, w, self.size, interpolation
            ), torchvision_F.resized_crop(
                    label_map, i / ratio, j / ratio, h / ratio, w / ratio, self.size, interpolation
            )
    

    Thanks

    opened by haooooooqi 1
  • Python3.6, ok; Python3.8, error

    Python3.6, ok; Python3.8, error

    Test: [ 0/1] Time: 11.293 (11.293) Loss: 0.7043 (0.7043) [email protected]: 42.1875 (42.1875) [email protected]: 100.0000 (100.0000) Test: [ 1/1] Time: 0.108 (5.701) Loss: 0.5847 (0.6689) [email protected]: 89.8148 (56.3187) [email protected]: 100.0000 (100.0000) free(): invalid pointer free(): invalid pointer Traceback (most recent call last): File "/opt/conda/lib/python3.8/runpy.py", line 194, in _run_module_as_main return _run_code(code, main_globals, None, File "/opt/conda/lib/python3.8/runpy.py", line 87, in _run_code exec(code, run_globals) File "/opt/conda/lib/python3.8/site-packages/torch/distributed/launch.py", line 303, in <module> main() File "/opt/conda/lib/python3.8/site-packages/torch/distributed/launch.py", line 294, in main raise subprocess.CalledProcessError(returncode=process.returncode, subprocess.CalledProcessError: Command '['/opt/conda/bin/python3.8', '-u', 'main.py', '--local_rank=1', './dataset/c/c', '--model', 'lvvit_s', '-b', '128', '--apex-amp', '--img-size', '224', '--drop-path', '0.1', '--token-label', '--token-label-size', '14', '--dense-weight', '0.0', '--num-classes', '2', '--finetune', './pretrained/lvvit_s-26M-384-84-4.pth.tar']' died with <Signals.SIGABRT: 6>. [email protected]:/puxin_libochao/TokenLabeling# CUDA_VISIBLE_DEVICES=0,1 bash ./distributed_train.sh 2 ./dataset/c/c --model lvvit_s -b 128 --apex-amp --img-size 224 --drop-path 0.1 --token-label --token-label-size 14 --dense-weight 0.0 --num-classes 2 --finetune ./pretrained/lvvit_s-26M-384-84-4.pth.tar

    opened by Williamlizl 1
  • A Bag of Training Techniques for ViT

    A Bag of Training Techniques for ViT

    Hi, thanks for your wonderful work. I have a question that whether training techniques mentioned in the LV-Vit can be used in other downstream task like object detection? In your paper, I see that many of this techniques are used in ImageNet. Thanks!

    opened by qdd1234 1
  • how to apply token labeling to CNN ?

    how to apply token labeling to CNN ?

    Hello ~ I'm interested in your token labeling technique, So I want to apply this technique in CNN based model because ViT is very heavy to train.

    can I get the your code with CNN token labeling? if you're not give me some detail for implementing

    thank you.

    opened by HoJ00n2 0
  • Model settings for Cifar10

    Model settings for Cifar10

    I am interested if there is any LV-ViT- model setup you have tested for Cifar10. I would like to know the proper setup of all blocks in none pretrained weights settings.

    opened by Aminullah6264 0
Owner
蒋子航
Now a Ph.D. student supervised by Prof. Feng Jiashi in ECE, NUS.
蒋子航
Incremental Transformer Structure Enhanced Image Inpainting with Masking Positional Encoding (CVPR2022)

Incremental Transformer Structure Enhanced Image Inpainting with Masking Positional Encoding by Qiaole Dong*, Chenjie Cao*, Yanwei Fu Paper and Supple

Qiaole Dong 190 Dec 27, 2022
This repository implements variational graph auto encoder by Thomas Kipf.

Variational Graph Auto-encoder in Pytorch This repository implements variational graph auto-encoder by Thomas Kipf. For details of the model, refer to

DaehanKim 215 Jan 02, 2023
NVTabular is a feature engineering and preprocessing library for tabular data designed to quickly and easily manipulate terabyte scale datasets used to train deep learning based recommender systems.

NVTabular is a feature engineering and preprocessing library for tabular data designed to quickly and easily manipulate terabyte scale datasets used to train deep learning based recommender systems.

880 Jan 07, 2023
SiT: Self-supervised vIsion Transformer

This repository contains the official PyTorch self-supervised pretraining, finetuning, and evaluation codes for SiT (Self-supervised image Transformer).

Sara Ahmed 275 Dec 28, 2022
Hydra Lightning Template for Structured Configs

Hydra Lightning Template for Structured Configs Template for creating projects with pytorch-lightning and hydra. How to use this template? Create your

Model-driven Machine Learning 4 Jul 19, 2022
Self-Supervised Learning of Event-based Optical Flow with Spiking Neural Networks

Self-Supervised Learning of Event-based Optical Flow with Spiking Neural Networks Work accepted at NeurIPS'21 [paper, video]. If you use this code in

TU Delft 43 Dec 07, 2022
yolov5目标检测模型的知识蒸馏(基于响应的蒸馏)

代码地址: https://github.com/Sharpiless/yolov5-knowledge-distillation 教师模型: python train.py --weights weights/yolov5m.pt \ --cfg models/yolov5m.ya

52 Dec 04, 2022
Attempt at implementation of a simple GAN using Keras

Simple GAN This is my attempt to make a wrapper class for a GAN in keras which can be used to abstract the whole architecture process. Simple GAN Over

Deven96 7 May 23, 2019
Experiments with Fourier layers on simulation data.

Factorized Fourier Neural Operators This repository contains the code to reproduce the results in our NeurIPS 2021 ML4PS workshop paper, Factorized Fo

Alasdair Tran 57 Dec 25, 2022
KE-Dialogue: Injecting knowledge graph into a fully end-to-end dialogue system.

Learning Knowledge Bases with Parameters for Task-Oriented Dialogue Systems This is the implementation of the paper: Learning Knowledge Bases with Par

CAiRE 42 Nov 10, 2022
Pytorch code for ICRA'21 paper: "Hierarchical Cross-Modal Agent for Robotics Vision-and-Language Navigation"

Hierarchical Cross-Modal Agent for Robotics Vision-and-Language Navigation This repository is the pytorch implementation of our paper: Hierarchical Cr

43 Nov 21, 2022
An implementation for the loss function proposed in Decoupled Contrastive Loss paper.

Decoupled-Contrastive-Learning This repository is an implementation for the loss function proposed in Decoupled Contrastive Loss paper. Requirements P

Ramin Nakhli 71 Dec 04, 2022
End-to-end speech secognition toolkit

End-to-end speech secognition toolkit This is an E2E ASR toolkit modified from Espnet1 (version 0.9.9). This is the official implementation of paper:

Jinchuan Tian 147 Dec 28, 2022
Official code for NeurIPS 2021 paper "Towards Scalable Unpaired Virtual Try-On via Patch-Routed Spatially-Adaptive GAN"

Towards Scalable Unpaired Virtual Try-On via Patch-Routed Spatially-Adaptive GAN Official code for NeurIPS 2021 paper "Towards Scalable Unpaired Virtu

68 Dec 21, 2022
Qt-GUI implementation of the YOLOv5 algorithm (ver.6 and ver.5)

YOLOv5-GUI 🎉 YOLOv5算法(ver.6及ver.5)的Qt-GUI实现 🎉 Qt-GUI implementation of the YOLOv5 algorithm (ver.6 and ver.5). 基于YOLOv5的v5版本和v6版本及Javacr大佬的UI逻辑进行编写

EricFang 12 Dec 28, 2022
Code for the CVPR 2021 paper: Understanding Failures of Deep Networks via Robust Feature Extraction

Welcome to Barlow Barlow is a tool for identifying the failure modes for a given neural network. To achieve this, Barlow first creates a group of imag

Sahil Singla 33 Dec 05, 2022
Vpw analyzer - A visual J1850 VPW analyzer written in Python

VPW Analyzer A visual J1850 VPW analyzer written in Python Requires Tkinter, Pan

7 May 01, 2022
Python version of the amazing Reaction Mechanism Generator (RMG).

Reaction Mechanism Generator (RMG) Description This repository contains the Python version of Reaction Mechanism Generator (RMG), a tool for automatic

Reaction Mechanism Generator 284 Dec 27, 2022
This is the code repository for the paper A hierarchical semantic segmentation framework for computer-vision-based bridge column damage detection

Bridge-damage-segmentation This is the code repository for the paper A hierarchical semantic segmentation framework for computer-vision-based bridge c

Jingxiao Liu 5 Dec 07, 2022
ACL'2021: LM-BFF: Better Few-shot Fine-tuning of Language Models

LM-BFF (Better Few-shot Fine-tuning of Language Models) This is the implementation of the paper Making Pre-trained Language Models Better Few-shot Lea

Princeton Natural Language Processing 607 Jan 07, 2023