PyTorch evaluation code for Delving Deep into the Generalization of Vision Transformers under Distribution Shifts.

Overview

Out-of-distribution Generalization Investigation on Vision Transformers


A Quick Glance at Our Work

A quick glance at our observations. left: Our investigation of the IID/OOD Generalization Gap implies that ViTs generalize better than CNNs under most types of distribution shifts. right: Combined with generalization-enhancing methods, ViTs achieve a significant 4% performance boost on OOD data compared with vanilla ViTs and consistently outperform the corresponding CNN models. The enhanced ViTs also have a smaller IID/OOD Generalization Gap than the enhanced BiT models.

Taxonomy of Distribution Shifts

Illustration of our taxonomy of distribution shifts. We build the taxonomy on which semantic concepts of the original image are modified. We divide distribution shifts into five cases: background shifts, corruption shifts, texture shifts, destruction shifts, and style shifts. We use the proxy A-distance (PAD) as an empirical measure of distribution shift. For each shift type we select a representative sample and rank the samples by their PAD values (shown next to the stars). Please refer to the paper for details.
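PAD is usually computed as PAD = 2(1 - 2ε), where ε is the held-out error of a classifier trained to separate source examples from target examples. The minimal sketch below illustrates that formula under stated assumptions: features are plain lists of floats, and the domain classifier is a nearest-centroid rule chosen for brevity (a linear classifier such as an SVM is the more common choice). The helper names `domain_classifier_error` and `proxy_a_distance` are hypothetical and are not part of this repository.

```python
# Sketch of the proxy A-distance: PAD = 2 * (1 - 2 * err), where err is the
# held-out error of a source-vs-target domain classifier.
# Assumption: a nearest-centroid classifier stands in for the usual linear model.
import math
import random
from typing import List, Sequence

Vector = Sequence[float]


def _mean(vectors: List[Vector]) -> List[float]:
    dim = len(vectors[0])
    return [sum(v[i] for v in vectors) / len(vectors) for i in range(dim)]


def _dist(a: Vector, b: Vector) -> float:
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def domain_classifier_error(source: List[Vector], target: List[Vector],
                            train_frac: float = 0.5, seed: int = 0) -> float:
    """Fit a nearest-centroid source-vs-target classifier; return held-out error."""
    rng = random.Random(seed)
    src, tgt = list(source), list(target)
    rng.shuffle(src)
    rng.shuffle(tgt)
    n_src, n_tgt = int(len(src) * train_frac), int(len(tgt) * train_frac)
    c_src, c_tgt = _mean(src[:n_src]), _mean(tgt[:n_tgt])
    # Held-out examples labeled 0 (source) or 1 (target).
    test = [(v, 0) for v in src[n_src:]] + [(v, 1) for v in tgt[n_tgt:]]
    wrong = 0
    for v, label in test:
        pred = 0 if _dist(v, c_src) <= _dist(v, c_tgt) else 1
        wrong += pred != label
    return wrong / len(test)


def proxy_a_distance(err: float) -> float:
    """0 means the domains are indistinguishable; 2 means fully separable."""
    return 2.0 * (1.0 - 2.0 * err)
```

A PAD near 0 means the two sets are hard to tell apart, while a PAD near 2 means they are almost perfectly separable, i.e. a larger shift.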

Datasets Used for Investigation

  • Background Shifts. ImageNet-9 is adopted for background shifts. ImageNet-9 is a family of 9-class datasets built with different foreground-background recombination schemes, which helps disentangle the impact of foreground and background signals on classification. We use the four variants that keep the foreground unchanged and generate a new background: 'Only-FG', 'Mixed-Same', 'Mixed-Rand', and 'Mixed-Next'. The 'Original' dataset represents in-distribution data.
  • Corruption Shifts. ImageNet-C is used to examine generalization under corruption shifts. ImageNet-C includes 15 types of algorithmically generated corruptions, grouped into 4 categories: 'noise', 'blur', 'weather', and 'digital'. Each corruption type has five severity levels, giving 75 distinct corruptions.
  • Texture Shifts. Cue Conflict Stimuli and Stylized-ImageNet are used to investigate generalization under texture shifts. Using style transfer, Geirhos et al. generated the Cue Conflict Stimuli benchmark with conflicting shape and texture information: the image texture is replaced by that of another class while the other object semantics are preserved. We therefore report the shape accuracy and the texture accuracy of each classifier separately. Stylized-ImageNet, also produced by Geirhos et al., replaces image textures with the styles of randomly selected paintings through AdaIN style transfer.
  • Destruction Shifts. Random patch-shuffling is used for destruction shifts: each image is split into a grid of patches that are then randomly shuffled. This process destroys long-range object information, and the severity increases as the number of splits grows. We also make a variant by further dividing each patch into two right triangles and shuffling the two types of triangles separately. We call this process triangular patch-shuffling.
  • Style Shifts. ImageNet-R and DomainNet are used for style shifts. ImageNet-R contains 30,000 images with various artistic renditions of 200 ImageNet classes. The renditions in ImageNet-R are real-world, naturally occurring variations, such as paintings or embroidery, whose textures and local image statistics differ from those of ImageNet images. DomainNet is a recent large-scale domain adaptation benchmark consisting of 345 classes and 6 domains. Because the labels of some domains are very noisy, we follow Saito et al. and use 4 domains (Real, Clipart, Painting, Sketch) with their 7 distribution shift scenarios.
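As a sketch of how the 75 ImageNet-C conditions above are typically enumerated and summarized, the snippet below lists the 15 corruption names as they appear in the public ImageNet-C release, grouped into the four categories, and averages per-condition accuracies per category. The function names and the layout of the accuracy dictionary are assumptions for illustration, not this repository's evaluation API.

```python
# Enumerate ImageNet-C conditions: 15 corruption types x 5 severities = 75.
# Hypothetical helpers; accuracy values passed in are whatever you measured.
from itertools import product
from statistics import mean
from typing import Dict, List, Tuple

CORRUPTIONS: Dict[str, Tuple[str, ...]] = {
    "noise": ("gaussian_noise", "shot_noise", "impulse_noise"),
    "blur": ("defocus_blur", "glass_blur", "motion_blur", "zoom_blur"),
    "weather": ("snow", "frost", "fog", "brightness"),
    "digital": ("contrast", "elastic_transform", "pixelate", "jpeg_compression"),
}
SEVERITIES = range(1, 6)


def all_conditions() -> List[Tuple[str, str, int]]:
    """Every (category, corruption, severity) triple."""
    return [(cat, name, sev)
            for cat, names in CORRUPTIONS.items()
            for name, sev in product(names, SEVERITIES)]


def mean_accuracy_by_category(acc: Dict[Tuple[str, int], float]) -> Dict[str, float]:
    """Average accuracy over all types and severities within each category.

    `acc` maps (corruption_name, severity) to an accuracy value.
    """
    return {cat: mean(acc[(name, sev)] for name in names for sev in SEVERITIES)
            for cat, names in CORRUPTIONS.items()}
```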
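For the cue-conflict evaluation, shape accuracy, texture accuracy, and the shape-bias ratio of Geirhos et al. (shape decisions divided by shape-plus-texture decisions) can be sketched as follows. The `CueConflictResult` record and `shape_texture_stats` helper are hypothetical; skipping images whose shape and texture labels coincide mirrors the benchmark's convention but should be checked against the paper's protocol.

```python
# Shape/texture accuracy and shape bias on cue-conflict images (sketch).
from typing import Iterable, NamedTuple


class CueConflictResult(NamedTuple):  # hypothetical record layout
    prediction: str
    shape_label: str
    texture_label: str


def shape_texture_stats(results: Iterable[CueConflictResult]) -> dict:
    n = shape_hits = texture_hits = 0
    for r in results:
        if r.shape_label == r.texture_label:
            continue  # no cue conflict, so the image says nothing about bias
        n += 1
        shape_hits += r.prediction == r.shape_label
        texture_hits += r.prediction == r.texture_label
    decided = shape_hits + texture_hits
    return {
        "shape_acc": shape_hits / n if n else 0.0,
        "texture_acc": texture_hits / n if n else 0.0,
        # Fraction of shape decisions among images classified by shape or texture.
        "shape_bias": shape_hits / decided if decided else float("nan"),
    }
```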
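The two destruction shifts can be sketched on a single-channel image stored as a list of lists. This is an illustrative reimplementation, not the repository's code: it assumes the image side is divisible by `splits`, that patches are square for the triangular variant, and that each patch is cut along its main diagonal with diagonal pixels assigned to the lower triangle. Lower triangles are permuted only among lower-triangle slots, and upper triangles only among upper-triangle slots.

```python
# Random patch-shuffling and triangular patch-shuffling (illustrative sketch).
import random
from typing import List

Image = List[List[int]]


def _patches(img: Image, splits: int):
    h, w = len(img), len(img[0])
    ph, pw = h // splits, w // splits
    coords = [(r * ph, c * pw) for r in range(splits) for c in range(splits)]
    return ph, pw, coords


def patch_shuffle(img: Image, splits: int, seed: int = 0) -> Image:
    """Cut img into splits x splits patches and place them in a random order."""
    ph, pw, coords = _patches(img, splits)
    order = coords[:]
    random.Random(seed).shuffle(order)
    out = [row[:] for row in img]
    for (dr, dc), (sr, sc) in zip(coords, order):
        for i in range(ph):
            out[dr + i][dc:dc + pw] = img[sr + i][sc:sc + pw]
    return out


def triangular_patch_shuffle(img: Image, splits: int, seed: int = 0) -> Image:
    """Split each square patch along its main diagonal and shuffle the lower
    and upper triangles separately (assumed diagonal and tie-breaking)."""
    ph, pw, coords = _patches(img, splits)
    assert ph == pw, "triangular variant assumes square patches"
    rng = random.Random(seed)
    lower_src, upper_src = coords[:], coords[:]
    rng.shuffle(lower_src)
    rng.shuffle(upper_src)
    out = [row[:] for row in img]
    for k, (dr, dc) in enumerate(coords):
        for i in range(ph):
            for j in range(pw):
                # Lower triangle (incl. diagonal) and upper triangle use
                # independent permutations of the source patches.
                sr, sc = lower_src[k] if j <= i else upper_src[k]
                out[dr + i][dc + j] = img[sr + i][sc + j]
    return out
```

Higher `splits` values produce smaller patches and therefore a more severe destruction shift.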

Generalization-Enhanced Vision Transformers

Framework overview of the three generalization-enhanced ViTs we design. All networks use a Vision Transformer as the feature encoder and a label prediction head. In this setting, the models take labeled source examples and unlabeled target examples as input. top left: T-ADV encourages the network to learn domain-invariant representations by introducing a domain classifier for domain adversarial training. top right: T-MME leverages a minimax game on the conditional entropy of target data to reduce the distribution gap while learning discriminative features for the task. The network uses a cosine similarity-based classifier to produce class prototypes. bottom: T-SSL is an end-to-end prototype-based self-supervised learning framework. The architecture uses two memory banks to calculate cluster centroids. A cosine classifier is used for classification in this framework.
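Domain adversarial training of the kind T-ADV describes is commonly implemented with a gradient reversal layer (GRL), as in Ganin and Lempitsky's DANN: the layer is the identity in the forward pass and multiplies the gradient by -λ in the backward pass. The pure-Python sketch below shows only the DANN coefficient schedule and the objective the encoder effectively minimizes; whether T-ADV uses this particular schedule is an assumption, and the function names are hypothetical.

```python
# DANN-style gradient-reversal coefficient schedule (assumed, not from this repo).
import math


def grl_coefficient(progress: float, gamma: float = 10.0) -> float:
    """lambda_p = 2 / (1 + exp(-gamma * p)) - 1, with p in [0, 1] the training progress."""
    return 2.0 / (1.0 + math.exp(-gamma * progress)) - 1.0


def feature_extractor_objective(cls_loss: float, domain_loss: float,
                                progress: float) -> float:
    """Loss the encoder effectively minimizes through a GRL: the label loss
    minus the scaled domain-classifier loss. The domain classifier itself
    still minimizes domain_loss."""
    return cls_loss - grl_coefficient(progress) * domain_loss
```

The schedule starts at 0, so the adversarial signal is suppressed while the domain classifier is still uninformative, and approaches 1 late in training.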
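The cosine classifier and the conditional entropy used by T-MME can be illustrated in plain Python. This is a sketch of the MME idea from Saito et al., not the repository's implementation: the temperature of 0.05 and the toy 2-D prototypes are assumptions, and in T-MME the features would come from the ViT encoder.

```python
# Cosine-similarity classifier and prediction entropy (MME-style sketch).
import math
from typing import List, Sequence


def _normalize(v: Sequence[float]) -> List[float]:
    norm = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / norm for x in v]


def cosine_logits(feature: Sequence[float], prototypes: List[Sequence[float]],
                  temperature: float = 0.05) -> List[float]:
    """logit_k = cos(feature, prototype_k) / T; prototypes act as class weights."""
    f = _normalize(feature)
    return [sum(a * b for a, b in zip(f, _normalize(p))) / temperature
            for p in prototypes]


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]


def entropy(probs: List[float]) -> float:
    """H(p) = -sum_k p_k log p_k (natural log)."""
    return -sum(p * math.log(p) for p in probs if p > 0)

# In MME, the prototypes are updated to maximize this entropy on unlabeled
# target features, while the feature extractor is updated to minimize it;
# only the forward quantities are computed here.
```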
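T-SSL's memory banks and cluster centroids can be sketched as follows, assuming one bank per domain that stores L2-normalized features, a momentum update for bank entries, and cluster assignments supplied externally (for example by k-means). The momentum value and the helper names `update_bank` and `centroids` are hypothetical, not taken from the repository.

```python
# Memory-bank update and cluster centroids for prototype-based SSL (sketch).
import math
from typing import List, Sequence

Vec = List[float]


def l2_normalize(v: Sequence[float]) -> Vec:
    n = math.sqrt(sum(x * x for x in v)) or 1.0  # zero vector stays zero
    return [x / n for x in v]


def update_bank(bank: List[Vec], indices: Sequence[int],
                features: Sequence[Sequence[float]], momentum: float = 0.5) -> None:
    """In place: bank[i] <- normalize(momentum * bank[i] + (1 - momentum) * f)."""
    for i, f in zip(indices, features):
        bank[i] = l2_normalize([momentum * b + (1 - momentum) * x
                                for b, x in zip(bank[i], f)])


def centroids(bank: List[Vec], assignments: Sequence[int], k: int) -> List[Vec]:
    """L2-normalized mean of the bank entries assigned to each of k clusters."""
    dim = len(bank[0])
    sums = [[0.0] * dim for _ in range(k)]
    for v, c in zip(bank, assignments):
        for d in range(dim):
            sums[c][d] += v[d]
    return [l2_normalize(s) for s in sums]
```

With two banks (source and target), `centroids` would be called once per bank to obtain per-domain prototypes.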

Run Our Code

Environment Installation

conda create -n vit python=3.6
conda activate vit
conda install pytorch==1.4.0 torchvision==0.5.0 cudatoolkit=10.0 -c pytorch

Before Running

conda activate vit
export PYTHONPATH=$PYTHONPATH:.

Evaluation

CUDA_VISIBLE_DEVICES=0 python main.py \
--model deit_small_b16_384 \
--num-classes 345 \
--checkpoint data/checkpoints/deit_small_b16_384_baseline_real.pth.tar \
--meta-file data/metas/DomainNet/sketch_test.jsonl \
--root-dir data/images/DomainNet/sketch/test
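To evaluate every DomainNet checkpoint on every test domain, the command above can be generated in a loop. The Python sketch below reuses only the flags shown in that command; the checkpoint and data paths for other domain pairs are assumed to follow the same naming pattern as the `real` checkpoint and `sketch` test split, which may not match the released files. It avoids `shlex.join` so that it also runs on the Python 3.6 environment above, and it only prints the commands unless `dry_run=False`.

```python
# Build (and optionally run) main.py evaluations for all domain pairs (sketch).
import os
import shlex
import subprocess
from typing import List

DOMAINS = ("clipart", "painting", "real", "sketch")


def eval_command(train_domain: str, test_domain: str,
                 model: str = "deit_small_b16_384", num_classes: int = 345) -> List[str]:
    # Path patterns are assumed from the single README example.
    return [
        "python", "main.py",
        "--model", model,
        "--num-classes", str(num_classes),
        "--checkpoint", "data/checkpoints/{}_baseline_{}.pth.tar".format(model, train_domain),
        "--meta-file", "data/metas/DomainNet/{}_test.jsonl".format(test_domain),
        "--root-dir", "data/images/DomainNet/{}/test".format(test_domain),
    ]


def all_commands() -> List[List[str]]:
    return [eval_command(s, t) for s in DOMAINS for t in DOMAINS]


def run_all(dry_run: bool = True) -> None:
    env = dict(os.environ, CUDA_VISIBLE_DEVICES="0")
    for cmd in all_commands():
        print(" ".join(shlex.quote(a) for a in cmd))
        if not dry_run:
            subprocess.run(cmd, env=env, check=True)
```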

Experimental Results

DomainNet

DeiT_small_b16_384

Confusion matrix for the baseline model

          clipart  painting  real   sketch
clipart   80.25    33.75     55.26  43.43
painting  36.89    75.32     52.08  31.14
real      50.59    45.81     84.78  39.31
sketch    52.16    35.27     48.19  71.92
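One simple way to read an IID/OOD generalization gap off this table is to subtract, for each domain, the mean of the three off-diagonal entries in its row from the diagonal entry. The sketch below does this with the baseline numbers above. It assumes rows index the training domain and columns the test domain, and this averaged gap is an illustrative definition rather than necessarily the one used in the paper.

```python
# Per-domain IID/OOD gap from the baseline matrix (rows assumed = training domain).
from typing import Dict, List

DOMAINS = ["clipart", "painting", "real", "sketch"]
ACC = [
    [80.25, 33.75, 55.26, 43.43],
    [36.89, 75.32, 52.08, 31.14],
    [50.59, 45.81, 84.78, 39.31],
    [52.16, 35.27, 48.19, 71.92],
]


def iid_ood_gaps(acc: List[List[float]], domains: List[str]) -> Dict[str, float]:
    gaps = {}
    for i, name in enumerate(domains):
        ood = [a for j, a in enumerate(acc[i]) if j != i]  # off-diagonal entries
        gaps[name] = acc[i][i] - sum(ood) / len(ood)       # IID minus mean OOD
    return gaps


for name, gap in iid_ood_gaps(ACC, DOMAINS).items():
    print("{:>8}: {:.2f}".format(name, gap))
```

With the numbers above this prints gaps of 36.10 (clipart), 35.28 (painting), 39.54 (real), and 26.71 (sketch).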

The models used above can be found here.

Remarks

  • These results may differ slightly from those in our paper due to differences in the environment.

  • We will continuously update this repo.

Citation

If you find these investigations useful in your research, please consider citing:

@misc{zhang2021delving,  
      title={Delving Deep into the Generalization of Vision Transformers under Distribution Shifts}, 
      author={Chongzhi Zhang and Mingyuan Zhang and Shanghang Zhang and Daisheng Jin and Qiang Zhou and Zhongang Cai and Haiyu Zhao and Shuai Yi and Xianglong Liu and Ziwei Liu},  
      year={2021},  
      eprint={2106.07617},  
      archivePrefix={arXiv},  
      primaryClass={cs.CV}  
}