Exploring whether attention is necessary for vision transformers

Overview

Do You Even Need Attention?
A Stack of Feed-Forward Layers Does Surprisingly Well on ImageNet

Paper/Report

TL;DR

We replace the attention layer in a vision transformer with a feed-forward layer and find that it still works quite well on ImageNet.

Abstract

The strong performance of vision transformers on image classification and other vision tasks is often attributed to the design of their multi-head attention layers. However, the extent to which attention is responsible for this strong performance remains unclear. In this short report, we ask: is the attention layer even necessary? Specifically, we replace the attention layer in a vision transformer with a feed-forward layer applied over the patch dimension. The resulting architecture is simply a series of feed-forward layers applied over the patch and feature dimensions in an alternating fashion. In experiments on ImageNet, this architecture performs surprisingly well: a ViT/DeiT-base-sized model obtains 74.9% top-1 accuracy, compared to 77.9% and 79.9% for ViT and DeiT respectively. These results indicate that aspects of vision transformers other than attention, such as the patch embedding, may be more responsible for their strong performance than previously thought. We hope these results prompt the community to spend more time trying to understand why our current models are as effective as they are.

Note

This is concurrent research with MLP-Mixer from Google Research. The ideas are exacty the same, with the one difference being that they use (a lot) more compute.

Pretrained models and logs

Here is a Weights and Biases report with the expected training trajectory: W&B

name [email protected] #params url
FF-tiny 61.4 7.7M model
FF-base 74.9 62M model
FF-large 71.4 206M -

Note: I haven't uploaded the FF-Large model because (1) it's over GitHub's file storage limit, and (2) I don't see why anyone would want it, given that it performs worse than the base model. That being said, if you want it, reach out to me and I'll send it to you.

How to train

The model definition in vision_transformer_linear.py is designed to be run with the repo from DeiT, which is itself based on the wonderful timm package.

Steps:

  • Clone the DeiT repo and move the file into it
git clone https://github.com/facebookresearch/deit
mv vision_transformer_linear.py deit
cd deit
  • Add a line to import vision_transformer_linear in main.py. For example, add the following after the import statements (around line 27):
+ import vision_transformer_linear
  • Train:
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch \
--nproc_per_node=8 \
--master_port 10490 \
--use_env main.py \
--model linear_tiny \
--batch-size 128 \
--drop 0.1 \
--output_dir outputs/linear-tiny \
--data-path your/path/to/imagenet

Citation

If you build upon this idea, feel free to drop a citation (and also cite MLP-Mixer).

@article{melaskyriazi2021doyoueven,
  title={Do You Even Need Attention? A Stack of Feed-Forward Layers Does Surprisingly Well on ImageNet},
  author={Luke Melas-Kyriazi},
  journal=arxiv,
  year=2021
}
You might also like...
The first dataset of composite images with rationality score indicating whether the object placement in a composite image is reasonable.
The first dataset of composite images with rationality score indicating whether the object placement in a composite image is reasonable.

Object-Placement-Assessment-Dataset-OPA Object-Placement-Assessment (OPA) is to verify whether a composite image is plausible in terms of the object p

Face Mask Detection is a project to determine whether someone is wearing mask or not, using deep neural network.
Face Mask Detection is a project to determine whether someone is wearing mask or not, using deep neural network.

face-mask-detection Face Mask Detection is a project to determine whether someone is wearing mask or not, using deep neural network. It contains 3 scr

This program was designed to detect whether someone is wearing a facemask through a live video stream.

This program was designed to detect whether someone is wearing a facemask through a live video stream. A custom lightweight CNN trained with TensorFlow on a public dataset provided by Kaggle is used to detect whether each face detected by the cv2 face detection dnn is wearing a mask

A system used to detect whether a person is wearing a medical mask or not.

Mask_Detection_System A system used to detect whether a person is wearing a medical mask or not. To open the program, please follow these steps: Make

Diabet Feature Engineering - Predict whether people have diabetes when their characteristics are specified

Diabet Feature Engineering - Predict whether people have diabetes when their characteristics are specified

Multivariate Time Series Forecasting with efficient Transformers. Code for the paper
Multivariate Time Series Forecasting with efficient Transformers. Code for the paper "Long-Range Transformers for Dynamic Spatiotemporal Forecasting."

Spacetimeformer Multivariate Forecasting This repository contains the code for the paper, "Long-Range Transformers for Dynamic Spatiotemporal Forecast

Implementation of SE3-Transformers for Equivariant Self-Attention, in Pytorch.

SE3 Transformer - Pytorch Implementation of SE3-Transformers for Equivariant Self-Attention, in Pytorch. May be needed for replicating Alphafold2 resu

Official PyTorch implementation for Generic Attention-model Explainability for Interpreting Bi-Modal and Encoder-Decoder Transformers, a novel method to visualize any Transformer-based network. Including examples for DETR, VQA.
Official PyTorch implementation for Generic Attention-model Explainability for Interpreting Bi-Modal and Encoder-Decoder Transformers, a novel method to visualize any Transformer-based network. Including examples for DETR, VQA.

PyTorch Implementation of Generic Attention-model Explainability for Interpreting Bi-Modal and Encoder-Decoder Transformers 1 Using Colab Please notic

Comments
  • On details about the experiments

    On details about the experiments

    In the experiment section of the report, you mention:

    Such a comparison is not exactly fair because the feed-forward model uses stronger training augmentations.

    I wonder what the augmentations are, for details about the augmentation seems missing from the paper.

    opened by boredtylin 2
  • Positional Encoding Ablation

    Positional Encoding Ablation

    Hi Luke, thank you for sharing this amazing work.

    In your arxiv document, I cannot find any mention of positional encoding, but I see that you use them in your code. Did you conduct any ablation study on the PE? i.e., how much does it affect the performance, with and without?

    Thank you in advance.

    opened by jjparkcv 0
  • Interaction between patches through a transpose may have a stronger role to play ?

    Interaction between patches through a transpose may have a stronger role to play ?

    Hi, I was going through your exp report. You have made a point that since you are able to get a good performance without using attention layer so good performance of ViT may be more to do with it's embedding layer than attention .

    But I believe, It's also may be to do with how you have established an interaction between patches through a transpose very similar to what was done in MLP-Mixer .

    Would love to know your thoughts on this ?

    opened by rakshith291 0
Owner
Luke Melas-Kyriazi
I'm student at Harvard University studying mathematics and computer science, always open to collaborate on interesting projects!
Luke Melas-Kyriazi
Instance-based label smoothing for improving deep neural networks generalization and calibration

Instance-based Label Smoothing for Neural Networks Pytorch Implementation of the algorithm. This repository includes a new proposed method for instanc

Mohamed Maher 1 Aug 13, 2022
Generative code template for PixelBeasts 10k NFT project.

generator-template Generative code template for combining transparent png attributes into 10,000 unique images. Used for the PixelBeasts 10k NFT proje

Yohei Nakajima 9 Aug 24, 2022
A scikit-learn-compatible module for estimating prediction intervals.

MAPIE - Model Agnostic Prediction Interval Estimator MAPIE allows you to easily estimate prediction intervals (or prediction sets) using your favourit

588 Jan 04, 2023
Implementation of MA-Trace - a general-purpose multi-agent RL algorithm for cooperative environments.

Off-Policy Correction For Multi-Agent Reinforcement Learning This repository is the official implementation of Off-Policy Correction For Multi-Agent R

4 Aug 18, 2022
My tensorflow implementation of "A neural conversational model", a Deep learning based chatbot

Deep Q&A Table of Contents Presentation Installation Running Chatbot Web interface Results Pretrained model Improvements Upgrade Presentation This wor

Conchylicultor 2.9k Dec 28, 2022
A Kaggle competition: discriminate gender based on handwriting

Gender discrimination based on handwriting See http://fastml.com/gender-discrimination/ for description. prep_data.py - a first step chunk_by_authors.

Zygmunt Zając 22 Jul 20, 2022
3D-Reconstruction 基于深度学习方法的单目多视图三维重建

基于深度学习方法的单目多视图三维重建 Part I 三维重建 代码:Part1 技术文档:[Markdown] [PDF] 原始图像:Original Images 点云结果:Point Cloud Results-1

HMT_Curo 19 Dec 26, 2022
This repository is for the preprint "A generative nonparametric Bayesian model for whole genomes"

BEAR Overview This repository contains code associated with the preprint A generative nonparametric Bayesian model for whole genomes (2021), which pro

Debora Marks Lab 10 Sep 18, 2022
Meandering In Networks of Entities to Reach Verisimilar Answers

MINERVA Meandering In Networks of Entities to Reach Verisimilar Answers Code and models for the paper Go for a Walk and Arrive at the Answer - Reasoni

Shehzaad Dhuliawala 271 Dec 13, 2022
a practicable framework used in Deep Learning. So far UDL only provide DCFNet implementation for the ICCV paper (Dynamic Cross Feature Fusion for Remote Sensing Pansharpening)

UDL UDL is a practicable framework used in Deep Learning (computer vision). Benchmark codes, results and models are available in UDL, please contact @

Xiao Wu 11 Sep 30, 2022
SAN for Product Attributes Prediction

SAN Heterogeneous Star Graph Attention Network for Product Attributes Prediction This repository contains the official PyTorch implementation for ADVI

Xuejiao Zhao 9 Dec 12, 2022
Implementation of a memory efficient multi-head attention as proposed in the paper, "Self-attention Does Not Need O(n²) Memory"

Memory Efficient Attention Pytorch Implementation of a memory efficient multi-head attention as proposed in the paper, Self-attention Does Not Need O(

Phil Wang 180 Jan 05, 2023
Nested cross-validation is necessary to avoid biased model performance in embedded feature selection in high-dimensional data with tiny sample sizes

Pruner for nested cross-validation - Sphinx-Doc Nested cross-validation is necessary to avoid biased model performance in embedded feature selection i

1 Dec 15, 2021
This is a computer vision based implementation of the popular childhood game 'Hand Cricket/Odd or Even' in python

Hand Cricket Table of Content Overview Installation Game rules Project Details Future scope Overview This is a computer vision based implementation of

Abhinav R Nayak 6 Jan 12, 2022
Data and code from COVID-19 machine learning paper

Machine learning approaches for localized lockdown, subnotification analysis and cases forecasting in São Paulo state counties during COVID-19 pandemi

Sara Malvar 4 Dec 22, 2022
The Few-Shot Bot: Prompt-Based Learning for Dialogue Systems

Few-Shot Bot: Prompt-Based Learning for Dialogue Systems This repository includes the dataset, experiments results, and code for the paper: Few-Shot B

Andrea Madotto 103 Dec 28, 2022
PyTorch reimplementation of the paper Involution: Inverting the Inherence of Convolution for Visual Recognition [CVPR 2021].

Involution: Inverting the Inherence of Convolution for Visual Recognition Unofficial PyTorch reimplementation of the paper Involution: Inverting the I

Christoph Reich 100 Dec 01, 2022
DANA paper supplementary materials

DANA Supplements This repository stores the data, results, and R scripts to generate these reuslts and figures for the corresponding paper Depth Norma

0 Dec 17, 2021
This porject is intented to build the most accurate model for predicting the porbability of loan default

Estimating-Loan-Default-Probability IBA ML2 Mid-project / Kaggle Competition This porject is intented to build the most accurate model for predicting

Adil Gahramanov 1 Jan 24, 2022
Keras Realtime Multi-Person Pose Estimation - Keras version of Realtime Multi-Person Pose Estimation project

This repository has become incompatible with the latest and recommended version of Tensorflow 2.0 Instead of refactoring this code painfully, I create

M Faber 769 Dec 08, 2022