(Preprint) Official PyTorch implementation of "How Do Vision Transformers Work?"

Overview

How Do Vision Transformers Work?

This repository provides a PyTorch implementation of "How Do Vision Transformers Work?" In the paper, we show that multi-head self-attentions (MSAs) for computer vision is NOT for capturing long-range dependency. In particular, we address the following three key questions of MSAs and Vision Transformers (ViTs):

  1. What properties of MSAs do we need to better optimize NNs? Do the long-range dependencies of MSAs help NNs learn?
  2. Do MSAs act like Convs? If not, how are they different?
  3. How can we harmonize MSAs with Convs? Can we just leverage their advantages?

We demonstrate that (1) MSAs flatten the loss landscapes, (2) MSA and Convs are complementary because MSAs are low-pass filters and convolutions (Convs) are high-pass filter, and (3) MSAs at the end of a stage significantly improve the accuracy.

Let's find the detailed answers below!

I. What Properties of MSAs Do We Need to Improve Optimization?

MSAs improve not only accuracy but also generalization by flattening the loss landscapes. Such improvement is primarily attributable to their data specificity, NOT long-range dependency 😱 Their weak inductive bias disrupts NN training. On the other hand, ViTs suffers from non-convex losses. MSAs allow negative Hessian eigenvalues in small data regimes. Large datasets and loss landscape smoothing methods alleviate this problem.

II. Do MSAs Act Like Convs?

MSAs and Convs exhibit opposite behaviors. For example, MSAs are low-pass filters, but Convs are high-pass filters. In addition, Convs are vulnerable to high-frequency noise but that MSAs are not. Therefore, MSAs and Convs are complementary.

III. How Can We Harmonize MSAs With Convs?

Multi-stage neural networks behave like a series connection of small individual models. In addition, MSAs at the end of a stage play a key role in prediction. Based on these insights, we propose design rules to harmonize MSAs with Convs. NN stages using this design pattern consists of a number of CNN blocks and one (or a few) MSA block. The design pattern naturally derives the structure of canonical Transformer, which has one MLP block for one MSA block.


In addition, we also introduce AlterNet, a model in which Conv blocks at the end of a stage are replaced with MSA blocks. Surprisingly, AlterNet outperforms CNNs not only in large data regimes but also in small data regimes. This contrasts with canonical ViTs, models that perform poorly on small amounts of data.

This repository is based on the official implementation of "Blurs Make Results Clearer: Spatial Smoothings to Improve Accuracy, Uncertainty, and Robustness". In this paper, we show that a simple (non-trainable) 2 ✕ 2 box blur filter improves accuracy, uncertainty, and robustness simultaneously by ensembling spatially nearby feature maps of CNNs. MSA is not simply generalized Conv, but rather a generalized (trainable) blur filter that complements Conv. Please check it out!

Getting Started

The following packages are required:

  • pytorch
  • matplotlib
  • notebook
  • ipywidgets
  • timm
  • einops
  • tensorboard
  • seaborn (optional)

We mainly use docker images pytorch/pytorch:1.9.0-cuda11.1-cudnn8-runtime for the code.

See classification.ipynb for image classification. Run all cells to train and test models on CIFAR-10, CIFAR-100, and ImageNet.

Metrics. We provide several metrics for measuring accuracy and uncertainty: Acuracy (Acc, ↑) and Acc for 90% certain results (Acc-90, ↑), negative log-likelihood (NLL, ↓), Expected Calibration Error (ECE, ↓), Intersection-over-Union (IoU, ↑) and IoU for certain results (IoU-90, ↑), Unconfidence (Unc-90, ↑), and Frequency for certain results (Freq-90, ↑). We also define a method to plot a reliability diagram for visualization.

Models. We provide AlexNet, VGG, pre-activation VGG, ResNet, pre-activation ResNet, ResNeXt, WideResNet, ViT, PiT, Swin, MLP-Mixer, and Alter-ResNet by default.

Visualizing the Loss Landscapes

Refer to losslandscape.ipynb for exploring the loss landscapes. It requires a trained model. Run all cells to get predictive performance of the model for weight space grid. We provide a sample loss landscape result.

Evaluating Robustness on Corrupted Datasets

Refer to robustness.ipynb for evaluation corruption robustness on corrupted datasets such as CIFAR-10-C and CIFAR-100-C. It requires a trained model. Run all cells to get predictive performance of the model on datasets which consist of data corrupted by 15 different types with 5 levels of intensity each. We provide a sample robustness result.

How to Apply MSA to Your Own Model

We find that MSA complements Conv (not replaces Conv), and MSA closer to the end of stage improves predictive performance significantly. Based on these insights, we propose the following build-up rules:

  1. Alternately replace Conv blocks with MSA blocks from the end of a baseline CNN model.
  2. If the added MSA block does not improve predictive performance, replace a Conv block located at the end of an earlier stage with an MSA
  3. Use more heads and higher hidden dimensions for MSA blocks in late stages.

In the animation above, we replace Convs of ResNet with MSAs one by one according to the build-up rules. Note that several MSAs in c3 harm the accuracy, but the MSA at the end of c2 improves it. As a result, surprisingly, the model with MSAs following the appropriate build-up rule outperforms CNNs even in the small data regime, e.g., CIFAR!

Caution: Investigate Loss Landscapes and Hessians With l2 Regularization on Augmented Datasets

Two common mistakes ⚠️ are investigating loss landscapes and Hessians (1) 'without considering l2 regularization' on (2) 'clean datasets'. However, note that NNs are optimized with l2 regularization on augmented datasets. Therefore, it is appropriate to visualize 'NLL + l2' on 'augmented datasets'. Measuring criteria without l2 on clean dataset would give incorrect (even opposite) results.

Citation

If you find this useful, please consider citing 📑 the paper and starring 🌟 this repository. Please do not hesitate to contact Namuk Park (email: namuk.park at gmail dot com, twitter: xxxnell) with any comments or feedback.

BibTex is TBD.

License

All code is available to you under Apache License 2.0. CNN models build off the torchvision models which are BSD licensed. ViTs build off the PyTorch Image Models and Vision Transformer - Pytorch which are Apache 2.0 and MIT licensed.

Copyright the maintainers.

Owner
xxxnell
Programmer & ML researcher
xxxnell
Proof-Of-Concept Piano-Drums Music AI Model/Implementation

Rock Piano "When all is one and one is all, that's what it is to be a rock and not to roll." ---Led Zeppelin, "Stairway To Heaven" Proof-Of-Concept Pi

Alex 4 Nov 28, 2021
Pipeline for employing a Lightweight deep learning models for LOW-power systems

PL-LOW A high-performance deep learning model lightweight pipeline that gradually lightens deep neural networks in order to utilize high-performance d

POSTECH Data Intelligence Lab 9 Aug 13, 2022
Sign Language is detected in realtime using video sequences. Our approach involves MediaPipe Holistic for keypoints extraction and LSTM Model for prediction.

RealTime Sign Language Detection using Action Recognition Approach Real-Time Sign Language is commonly predicted using models whose architecture consi

Rishikesh S 15 Aug 20, 2022
TigerLily: Finding drug interactions in silico with the Graph.

Drug Interaction Prediction with Tigerlily Documentation | Example Notebook | Youtube Video | Project Report Tigerlily is a TigerGraph based system de

Benedek Rozemberczki 91 Dec 30, 2022
Code for Paper: Self-supervised Learning of Motion Capture

Self-supervised Learning of Motion Capture This is code for the paper: Hsiao-Yu Fish Tung, Hsiao-Wei Tung, Ersin Yumer, Katerina Fragkiadaki, Self-sup

Hsiao-Yu Fish Tung 87 Jul 25, 2022
PyTorch Implementation of Unsupervised Depth Completion with Calibrated Backprojection Layers (ORAL, ICCV 2021)

Unsupervised Depth Completion with Calibrated Backprojection Layers PyTorch implementation of Unsupervised Depth Completion with Calibrated Backprojec

80 Dec 13, 2022
Grow Function: Generate 3D Stacked Bifurcating Double Deep Cellular Automata based organisms which differentiate using a Genetic Algorithm...

Grow Function: A 3D Stacked Bifurcating Double Deep Cellular Automata which differentiates using a Genetic Algorithm... TLDR;High Def Trees that you can mint as NFTs on Solana

Nathaniel Gibson 4 Oct 08, 2022
Pytorch implementation of the DeepDream computer vision algorithm

deep-dream-in-pytorch Pytorch (https://github.com/pytorch/pytorch) implementation of the deep dream (https://en.wikipedia.org/wiki/DeepDream) computer

102 Dec 05, 2022
Crowd-Kit is a powerful Python library that implements commonly-used aggregation methods for crowdsourced annotation and offers the relevant metrics and datasets

Crowd-Kit: Computational Quality Control for Crowdsourcing Documentation Crowd-Kit is a powerful Python library that implements commonly-used aggregat

Toloka 125 Dec 30, 2022
Graph InfoClust: Leveraging cluster-level node information for unsupervised graph representation learning

Graph-InfoClust-GIC [PAKDD 2021] PAKDD'21 version Graph InfoClust: Maximizing Coarse-Grain Mutual Information in Graphs Preprint version Graph InfoClu

Costas Mavromatis 21 Dec 03, 2022
A tool to analyze leveraged liquidity mining and find optimal option combination for hedging.

LP-Option-Hedging Description A Python program to analyze leveraged liquidity farming/mining and find the optimal option combination for hedging imper

Aureliano 18 Dec 19, 2022
VGGVox models for Speaker Identification and Verification trained on the VoxCeleb (1 & 2) datasets

VGGVox models for speaker identification and verification This directory contains code to import and evaluate the speaker identification and verificat

338 Dec 27, 2022
Implementation of 🦩 Flamingo, state-of-the-art few-shot visual question answering attention net out of Deepmind, in Pytorch

🦩 Flamingo - Pytorch Implementation of Flamingo, state-of-the-art few-shot visual question answering attention net, in Pytorch. It will include the p

Phil Wang 630 Dec 28, 2022
C3DPO - Canonical 3D Pose Networks for Non-rigid Structure From Motion.

C3DPO: Canonical 3D Pose Networks for Non-Rigid Structure From Motion By: David Novotny, Nikhila Ravi, Benjamin Graham, Natalia Neverova, Andrea Vedal

Meta Research 309 Dec 16, 2022
This repository is for our EMNLP 2021 paper "Automated Generation of Accurate & Fluent Medical X-ray Reports"

Introduction: X-Ray Report Generation This repository is for our EMNLP 2021 paper "Automated Generation of Accurate & Fluent Medical X-ray Reports". O

no name 36 Dec 16, 2022
Contains modeling practice materials and homework for the Computational Neuroscience course at Okinawa Institute of Science and Technology

A310 Computational Neuroscience - Okinawa Institute of Science and Technology, 2022 This repository contains modeling practice materials and homework

Sungho Hong 1 Jan 24, 2022
Go from graph data to a secure and interactive visual graph app in 15 minutes. Batteries-included self-hosting of graph data apps with Streamlit, Graphistry, RAPIDS, and more!

✔️ Linux ✔️ OS X ❌ Windows (#39) Welcome to graph-app-kit Turn your graph data into a secure and interactive visual graph app in 15 minutes! Why This

Graphistry 107 Jan 02, 2023
git git《Transformer Meets Tracker: Exploiting Temporal Context for Robust Visual Tracking》(CVPR 2021) GitHub:git2] 《Masksembles for Uncertainty Estimation》(CVPR 2021) GitHub:git3]

Transformer Meets Tracker: Exploiting Temporal Context for Robust Visual Tracking Ning Wang, Wengang Zhou, Jie Wang, and Houqiang Li Accepted by CVPR

NingWang 236 Dec 22, 2022
This is the repository of shape matching algorithm Iterative Rotations and Assignments (IRA)

Description This is the repository of shape matching algorithm Iterative Rotations and Assignments (IRA), described in the publication [1]. Directory

MAMMASMIAS Consortium 6 Nov 14, 2022
Losslandscapetaxonomy - Taxonomizing local versus global structure in neural network loss landscapes

Taxonomizing local versus global structure in neural network loss landscapes Int

Yaoqing Yang 8 Dec 30, 2022