Exploring whether attention is necessary for vision transformers

Overview

Do You Even Need Attention?
A Stack of Feed-Forward Layers Does Surprisingly Well on ImageNet

Paper/Report

TL;DR

We replace the attention layer in a vision transformer with a feed-forward layer and find that it still works quite well on ImageNet.

Abstract

The strong performance of vision transformers on image classification and other vision tasks is often attributed to the design of their multi-head attention layers. However, the extent to which attention is responsible for this strong performance remains unclear. In this short report, we ask: is the attention layer even necessary? Specifically, we replace the attention layer in a vision transformer with a feed-forward layer applied over the patch dimension. The resulting architecture is simply a series of feed-forward layers applied over the patch and feature dimensions in an alternating fashion. In experiments on ImageNet, this architecture performs surprisingly well: a ViT/DeiT-base-sized model obtains 74.9% top-1 accuracy, compared to 77.9% and 79.9% for ViT and DeiT respectively. These results indicate that aspects of vision transformers other than attention, such as the patch embedding, may be more responsible for their strong performance than previously thought. We hope these results prompt the community to spend more time trying to understand why our current models are as effective as they are.

Note

This is concurrent research with MLP-Mixer from Google Research. The ideas are exacty the same, with the one difference being that they use (a lot) more compute.

Pretrained models and logs

Here is a Weights and Biases report with the expected training trajectory: W&B

name [email protected] #params url
FF-tiny 61.4 7.7M model
FF-base 74.9 62M model
FF-large 71.4 206M -

Note: I haven't uploaded the FF-Large model because (1) it's over GitHub's file storage limit, and (2) I don't see why anyone would want it, given that it performs worse than the base model. That being said, if you want it, reach out to me and I'll send it to you.

How to train

The model definition in vision_transformer_linear.py is designed to be run with the repo from DeiT, which is itself based on the wonderful timm package.

Steps:

  • Clone the DeiT repo and move the file into it
git clone https://github.com/facebookresearch/deit
mv vision_transformer_linear.py deit
cd deit
  • Add a line to import vision_transformer_linear in main.py. For example, add the following after the import statements (around line 27):
+ import vision_transformer_linear
  • Train:
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch \
--nproc_per_node=8 \
--master_port 10490 \
--use_env main.py \
--model linear_tiny \
--batch-size 128 \
--drop 0.1 \
--output_dir outputs/linear-tiny \
--data-path your/path/to/imagenet

Citation

If you build upon this idea, feel free to drop a citation (and also cite MLP-Mixer).

@article{melaskyriazi2021doyoueven,
  title={Do You Even Need Attention? A Stack of Feed-Forward Layers Does Surprisingly Well on ImageNet},
  author={Luke Melas-Kyriazi},
  journal=arxiv,
  year=2021
}
You might also like...
The first dataset of composite images with rationality score indicating whether the object placement in a composite image is reasonable.
The first dataset of composite images with rationality score indicating whether the object placement in a composite image is reasonable.

Object-Placement-Assessment-Dataset-OPA Object-Placement-Assessment (OPA) is to verify whether a composite image is plausible in terms of the object p

Face Mask Detection is a project to determine whether someone is wearing mask or not, using deep neural network.
Face Mask Detection is a project to determine whether someone is wearing mask or not, using deep neural network.

face-mask-detection Face Mask Detection is a project to determine whether someone is wearing mask or not, using deep neural network. It contains 3 scr

This program was designed to detect whether someone is wearing a facemask through a live video stream.

This program was designed to detect whether someone is wearing a facemask through a live video stream. A custom lightweight CNN trained with TensorFlow on a public dataset provided by Kaggle is used to detect whether each face detected by the cv2 face detection dnn is wearing a mask

A system used to detect whether a person is wearing a medical mask or not.

Mask_Detection_System A system used to detect whether a person is wearing a medical mask or not. To open the program, please follow these steps: Make

Diabet Feature Engineering - Predict whether people have diabetes when their characteristics are specified

Diabet Feature Engineering - Predict whether people have diabetes when their characteristics are specified

Multivariate Time Series Forecasting with efficient Transformers. Code for the paper
Multivariate Time Series Forecasting with efficient Transformers. Code for the paper "Long-Range Transformers for Dynamic Spatiotemporal Forecasting."

Spacetimeformer Multivariate Forecasting This repository contains the code for the paper, "Long-Range Transformers for Dynamic Spatiotemporal Forecast

Implementation of SE3-Transformers for Equivariant Self-Attention, in Pytorch.

SE3 Transformer - Pytorch Implementation of SE3-Transformers for Equivariant Self-Attention, in Pytorch. May be needed for replicating Alphafold2 resu

Official PyTorch implementation for Generic Attention-model Explainability for Interpreting Bi-Modal and Encoder-Decoder Transformers, a novel method to visualize any Transformer-based network. Including examples for DETR, VQA.
Official PyTorch implementation for Generic Attention-model Explainability for Interpreting Bi-Modal and Encoder-Decoder Transformers, a novel method to visualize any Transformer-based network. Including examples for DETR, VQA.

PyTorch Implementation of Generic Attention-model Explainability for Interpreting Bi-Modal and Encoder-Decoder Transformers 1 Using Colab Please notic

Comments
  • On details about the experiments

    On details about the experiments

    In the experiment section of the report, you mention:

    Such a comparison is not exactly fair because the feed-forward model uses stronger training augmentations.

    I wonder what the augmentations are, for details about the augmentation seems missing from the paper.

    opened by boredtylin 2
  • Positional Encoding Ablation

    Positional Encoding Ablation

    Hi Luke, thank you for sharing this amazing work.

    In your arxiv document, I cannot find any mention of positional encoding, but I see that you use them in your code. Did you conduct any ablation study on the PE? i.e., how much does it affect the performance, with and without?

    Thank you in advance.

    opened by jjparkcv 0
  • Interaction between patches through a transpose may have a stronger role to play ?

    Interaction between patches through a transpose may have a stronger role to play ?

    Hi, I was going through your exp report. You have made a point that since you are able to get a good performance without using attention layer so good performance of ViT may be more to do with it's embedding layer than attention .

    But I believe, It's also may be to do with how you have established an interaction between patches through a transpose very similar to what was done in MLP-Mixer .

    Would love to know your thoughts on this ?

    opened by rakshith291 0
Owner
Luke Melas-Kyriazi
I'm student at Harvard University studying mathematics and computer science, always open to collaborate on interesting projects!
Luke Melas-Kyriazi
Official Code Implementation of the paper : XAI for Transformers: Better Explanations through Conservative Propagation

Official Code Implementation of The Paper : XAI for Transformers: Better Explanations through Conservative Propagation For the SST-2 and IMDB expermin

Ameen Ali 23 Dec 30, 2022
This is the official PyTorch implementation of our paper: "Artistic Style Transfer with Internal-external Learning and Contrastive Learning".

Artistic Style Transfer with Internal-external Learning and Contrastive Learning This is the official PyTorch implementation of our paper: "Artistic S

51 Dec 20, 2022
Jupyter notebooks showing best practices for using cx_Oracle, the Python DB API for Oracle Database

Python cx_Oracle Notebooks, 2022 The repository contains Jupyter notebooks showing best practices for using cx_Oracle, the Python DB API for Oracle Da

Christopher Jones 13 Dec 15, 2022
Code for Transformer Hawkes Process, ICML 2020.

Transformer Hawkes Process Source code for Transformer Hawkes Process (ICML 2020). Run the code Dependencies Python 3.7. Anaconda contains all the req

Simiao Zuo 111 Dec 26, 2022
Game Agent Framework. Helping you create AIs / Bots that learn to play any game you own!

Serpent.AI - Game Agent Framework (Python) Update: Revival (May 2020) Development work has resumed on the framework with the aim of bringing it into 2

Serpent.AI 6.4k Jan 05, 2023
Code for the ICML 2021 paper: "ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision"

ViLT Code for the paper: "ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision" Install pip install -r requirements.txt pip

Wonjae Kim 922 Jan 01, 2023
Dark Finix: All in one hacking framework with almost 100 tools

Dark Finix - Hacking Framework. Dark Finix is a all in one hacking framework wit

Md. Nur habib 2 Feb 18, 2022
Meta Self-learning for Multi-Source Domain Adaptation: A Benchmark

Meta Self-Learning for Multi-Source Domain Adaptation: A Benchmark Project | Arxiv | YouTube | | Abstract In recent years, deep learning-based methods

CVSM Group - email: <a href=[email protected]"> 188 Dec 12, 2022
Implementation of ToeplitzLDA for spatiotemporal stationary time series data.

Code for the ToeplitzLDA classifier proposed in here. The classifier conforms sklearn and can be used as a drop-in replacement for other LDA classifiers. For in-depth usage refer to the learning from

Jan Sosulski 5 Nov 07, 2022
A PyTorch Implementation of Single Shot MultiBox Detector

SSD: Single Shot MultiBox Object Detector, in PyTorch A PyTorch implementation of Single Shot MultiBox Detector from the 2016 paper by Wei Liu, Dragom

Max deGroot 4.8k Jan 07, 2023
A simple implementation of Kalman filter in single object tracking

kalman-filter-in-single-object-tracking A simple implementation of Kalman filter in single object tracking https://www.bilibili.com/video/BV1Qf4y1J7D4

130 Dec 26, 2022
Gender Classification Machine Learning Model using Sk-learn in Python with 97%+ accuracy and deployment

Gender-classification This is a ML model to classify Male and Females using some physical characterstics Data. Python Libraries like Pandas,Numpy and

Aryan raj 11 Oct 16, 2022
Rational Activation Functions - Replacing Padé Activation Units

Rational Activations - Learnable Rational Activation Functions First introduce as PAU in Padé Activation Units: End-to-end Learning of Activation Func

<a href=[email protected]"> 38 Nov 22, 2022
This is the pytorch implementation for the paper: *Learning Accurate Performance Predictors for Ultrafast Automated Model Compression*, which is in submission to TPAMI

SeerNet This is the pytorch implementation for the paper: Learning Accurate Performance Predictors for Ultrafast Automated Model Compression, which is

3 May 01, 2022
DECAF: Generating Fair Synthetic Data Using Causally-Aware Generative Networks

DECAF (DEbiasing CAusal Fairness) Code Author: Trent Kyono This repository contains the code used for the "DECAF: Generating Fair Synthetic Data Using

van_der_Schaar \LAB 7 Nov 24, 2022
Camview - A CLI-tool used to stream CCTV online footage based on URL params

CamView A CLI-tool used to stream CCTV online footage based on URL params Get St

Finn Lancaster 54 Dec 09, 2022
This is the reference implementation for "Coresets via Bilevel Optimization for Continual Learning and Streaming"

Coresets via Bilevel Optimization This is the reference implementation for "Coresets via Bilevel Optimization for Continual Learning and Streaming" ht

Zalán Borsos 51 Dec 30, 2022
This repository compare a selfie with images from identity documents and response if the selfie match.

aws-rekognition-facecompare This repository compare a selfie with images from identity documents and response if the selfie match. This code was made

1 Jan 27, 2022
PowerGridworld: A Framework for Multi-Agent Reinforcement Learning in Power Systems

PowerGridworld provides users with a lightweight, modular, and customizable framework for creating power-systems-focused, multi-agent Gym environments that readily integrate with existing training fr

National Renewable Energy Laboratory 37 Dec 17, 2022
Official repository for Automated Learning Rate Scheduler for Large-Batch Training (8th ICML Workshop on AutoML)

Automated Learning Rate Scheduler for Large-Batch Training The official repository for Automated Learning Rate Scheduler for Large-Batch Training (8th

Kakao Brain 35 Jan 04, 2023