PoolFormer: MetaFormer is Actually What You Need for Vision

Overview


This is a PyTorch implementation of PoolFormer, proposed in our paper "MetaFormer is Actually What You Need for Vision".

MetaFormer

Figure 1: MetaFormer and performance of MetaFormer-based models on the ImageNet-1K validation set. We argue that the competence of transformer/MLP-like models primarily stems from the general architecture MetaFormer rather than from the specific token mixers they are equipped with. To demonstrate this, we use an embarrassingly simple non-parametric operator, pooling, to conduct extremely basic token mixing. Surprisingly, the resulting model, PoolFormer, consistently outperforms DeiT and ResMLP as shown in (b), which strongly supports the claim that MetaFormer is actually what we need to achieve competitive performance.

Figure 2: (a) The overall framework of PoolFormer. (b) The architecture of the PoolFormer block. Compared with the transformer block, it replaces attention with an extremely simple non-parametric operator, pooling, to conduct only basic token mixing.
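
To make the token mixer in Figure 2(b) concrete, here is a minimal, dependency-free sketch of pooling-based token mixing on a single-channel 2D grid of token values. It is not the code in this repo, which works on PyTorch tensors. The function name `pool_token_mixer`, the 3x3 window, averaging over in-bounds neighbours only, and subtracting the input (on the assumption that the block's residual connection adds it back) are illustrative choices.

```python
def pool_token_mixer(grid, pool_size=3):
    """Average-pool a 2D grid of token values, then subtract the input.

    Illustrative only: a single channel, stride 1, and border windows that
    average over in-bounds neighbours (padding is not counted).
    """
    height, width = len(grid), len(grid[0])
    radius = pool_size // 2
    mixed = []
    for i in range(height):
        row = []
        for j in range(width):
            # Gather the neighbours that fall inside the grid.
            window = [
                grid[y][x]
                for y in range(max(0, i - radius), min(height, i + radius + 1))
                for x in range(max(0, j - radius), min(width, j + radius + 1))
            ]
            # Local average minus the token itself: no learnable parameters.
            row.append(sum(window) / len(window) - grid[i][j])
        mixed.append(row)
    return mixed
```

With this formulation a constant grid maps to all zeros, so the mixer only contributes where a token differs from its neighbourhood average.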

Bibtex

@article{yu2021metaformer,
  title={MetaFormer is Actually What You Need for Vision},
  author={Yu, Weihao and Luo, Mi and Zhou, Pan and Si, Chenyang and Zhou, Yichen and Wang, Xinchao and Feng, Jiashi and Yan, Shuicheng},
  journal={arXiv preprint arXiv:2111.11418},
  year={2021}
}

1. Requirements

For Image Classification (Configs of detection and segmentation will be available soon)

torch>=1.7.0; torchvision>=0.8.0; pyyaml; apex-amp (if you want to use fp16); timm (pip install git+https://github.com/rwightman/p[email protected])

Data preparation: ImageNet with the following folder structure. You can extract ImageNet with this script.

│imagenet/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......
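
Before launching a run, it can help to check that the dataset folder matches the layout above. The following is a small sketch using only the standard library; `summarize_imagenet` is a hypothetical helper, not part of this repo, and it assumes every image uses the `.JPEG` extension shown in the tree.

```python
from pathlib import Path


def summarize_imagenet(root):
    """Return {split: (num_class_dirs, num_images)} for train/ and val/."""
    summary = {}
    for split in ("train", "val"):
        split_dir = Path(root) / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing split directory: {split_dir}")
        # Each class is a sub-directory named by its WordNet ID, e.g. n01440764.
        class_dirs = [d for d in split_dir.iterdir() if d.is_dir()]
        num_images = sum(
            1
            for d in class_dirs
            for f in d.iterdir()
            if f.suffix.upper() == ".JPEG"
        )
        summary[split] = (len(class_dirs), num_images)
    return summary
```

Calling `summarize_imagenet("/path/to/imagenet")` returns the class and image counts per split, which can then be compared against the counts you expect for your copy of the dataset.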

Directory structure in this repo:

│poolformer/
├──misc/
├──models/
│  ├── __init__.py
│  ├── poolformer.py
├──LICENSE
├──README.md
├──distributed_train.sh
├──train.py
├──validate.py

2. PoolFormer Models

Model            #Params   Image resolution   Top-1 Acc (%)   Download
poolformer_s12   12M       224                77.2            here
poolformer_s24   21M       224                80.3            here
poolformer_s36   31M       224                81.4            here
poolformer_m36   56M       224                82.1            here
poolformer_m48   73M       224                82.5            here

All the pretrained models can also be downloaded from Baidu Yun (password: esac).

Update ResNet Scores in the paper

Updated_ResNet_Scores

[1] He et al., "Deep Residual Learning for Image Recognition", CVPR 2016.

[2] Wightman et al., "ResNet strikes back: An improved training procedure in timm", arXiv preprint arXiv:2110.00476, 2021.

Usage

We also provide a Colab notebook which runs the steps to perform inference with PoolFormer.

3. Validation

To evaluate our PoolFormer models, run:

MODEL=poolformer_s12 # poolformer_{s12, s24, s36, m36, m48}
python3 validate.py /path/to/imagenet  --model $MODEL \
  --checkpoint /path/to/checkpoint -b 128

4. Train

We show how to train PoolFormers on 8 GPUs. The learning rate scales with the batch size as lr = bs/1024 * 1e-3. For example, with a batch size of 1024 the learning rate is set to 1e-3 (for a batch size of 1024, a learning rate of 2e-3 sometimes gives better performance).

MODEL=poolformer_s12 # poolformer_{s12, s24, s36, m36, m48}
DROP_PATH=0.1 # drop path rates [0.1, 0.1, 0.2, 0.3, 0.4] corresponding to models [s12, s24, s36, m36, m48]
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 /path/to/imagenet \
  --model $MODEL -b 128 --lr 1e-3 --drop-path $DROP_PATH --apex-amp
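
The learning-rate rule and the per-model drop path rates above can be captured in a small helper. This is only a sketch: `scaled_lr` and `train_args` are hypothetical names, not part of this repo, and it assumes that `-b` is the per-GPU batch size, so that 8 GPUs x 128 gives the total batch size of 1024.

```python
# Drop path rates per model variant, as listed in the training command above.
DROP_PATH = {"s12": 0.1, "s24": 0.1, "s36": 0.2, "m36": 0.3, "m48": 0.4}


def scaled_lr(total_batch_size, base_lr=1e-3, base_batch_size=1024):
    """Linear scaling rule: lr = bs / 1024 * 1e-3."""
    return total_batch_size / base_batch_size * base_lr


def train_args(variant, num_gpus=8, per_gpu_batch=128):
    """Build the model-specific arguments for the training script."""
    # Assumption: -b is the per-GPU batch size, so the total is GPUs x -b.
    total_batch_size = num_gpus * per_gpu_batch
    return [
        "--model", f"poolformer_{variant}",
        "-b", str(per_gpu_batch),
        "--lr", f"{scaled_lr(total_batch_size):g}",
        "--drop-path", str(DROP_PATH[variant]),
    ]
```

For instance, halving the number of GPUs to 4 while keeping `-b 128` gives a total batch size of 512, for which `scaled_lr(512)` evaluates to 5e-4.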

5. Acknowledgment

Our implementation is mainly based on the following codebases. We sincerely thank the authors for their wonderful work.

pytorch-image-models, mmdetection, mmsegmentation.

In addition, Weihao Yu would like to thank the TPU Research Cloud (TRC) program for supporting part of the computational resources.

LICENSE

This repo is under the Apache-2.0 license. For commercial use, please contact the authors.

Comments
  • About Normalization

    Hi, thanks for your excellent work. In your ablation studies (section 4.4), you compared Group Normalization (group number is set as 1 for simplicity), Layer Normalization, and Batch Normalization. The conclusion is that Group Normalization is 0.7% or 0.8% higher than Layer Normalization or Batch Normalization. But when the number of groups is 1, Group Normalization is equivalent to Layer Normalization, right?

    opened by tinyalpha 10
  • Addition of the Organization on HuggingFace Transformers

    Hello PoolFormer team!

    I have been working on porting the implementation of PoolFormer to HuggingFace Transformers library (you can see my PR here) and I was wondering if I can go ahead and add Sea AI labs as an organization to the HuggingFace models hub.

    This will allow all model checkpoints to be uploaded onto the hub as well as model cards, etc.

    Kind regards, Tanay Mehta

    opened by heytanay 7
  • How to measure MACs?

    Hi, thanks for your nice work :) I also watched your presentation record through this conference.

    I want to apply PoolFormer in my work. Could you tell me how you measured the MACs of the architecture introduced in your paper? If it is not too much trouble, could you share your measurement code?

    opened by DoranLyong 5
  • why use use_layer_scale

    Thanks for your great contribution! In the implementation of PoolFormerBlock, there is a layer_scale after token_mixer. What is the impact of this operation?

    opened by rtfgithub 5
  • Invitation of making PR for OpenMMLab / MMSegmentation.

    Hi, first of all, congratulations on the acceptance to CVPR 2022. This great work deserves it.

    I am a member of OpenMMLab and mainly work on developing MMSegmentation. I think that if PoolFormer were officially supported, many more people would use it for benchmarking, which would promote research in computer vision.

    Would you like to make a PR for OpenMMLab? We could work together to refactor your code and use our own GPUs to train and re-implement it.

    I think it would be pretty cool because it would let more researchers and community members use this excellent work! Here is our re-implementation work: ConvNeXt.

    We do hope PoolFormer can also be added as a backbone in our codebase so that many researchers can use it directly for downstream tasks.

    Looking forward to your reply!

    Best,

    opened by MengzhangLI 5
  • why the speed slower than pvtv2-b1?

    Recently I trained a transformer-based instance segmentation model and tested it with different backbones. Here are the results and the speed test:

    image

    The batch size listed is the training batch size. Why is PoolFormer the slowest one? Is that normal?

    It is slower than pvtv2-b1, and its precision is lower as well...

    opened by jinfagang 5
  • Checkpoints of the Ablation study

    Hi, thanks for your amazing work. I am reading Table 6, and I am surprised that the method is so simple yet so effective, especially when pooling is replaced with identity mapping: 74.3% top-1 on ImageNet-1K with only Conv1x1 and Norm layers. I am thrilled... Can you release this checkpoint so that we can verify it? Thanks again.

    opened by chuong98 5
  • Design on positional embedding?

    Hello authors,

    I appreciate a lot your current work, which inspired the community. I am here to raise a very simple and quick question after checking the code and architecture design.

    I observed that when the network uses pooling, MLP, or identity as the token mixer, you do not include a positional embedding; you only use this component with MHA. What is the reasoning behind this design, and why do the other models not rely on this embedding?

    Best,

    discussion 
    opened by jizongFox 4
  • Error: About self.pool(x)

    Hello, I am quite interested in the PoolFormer you proposed, but an error occurred when using PoolFormerBlock, as follows:

    Traceback (most recent call last):
      File "train.py", line 545, in <module>
        train(hyp, opt, device, tb_writer)
      File "train.py", line 89, in train
        model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
      File "E:\Work\yolov5\models\yolo.py", line 106, in __init__
        m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward
      File "E:\Work\yolov5\models\yolo.py", line 138, in forward
        return self.forward_once(x, profile) # single-scale inference, train
      File "E:\Work\yolov5\models\yolo.py", line 157, in forward_once
        x = m(x) # run # execute the network component
      File "C:\conda\conda\envs\torch17\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
        result = self.forward(*input, **kwargs)
      File "E:\Work\yolov5_T23\models\common.py", line 194, in forward
        n = self.token_mixer(m)
      File "C:\conda\conda\envs\torch17\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
        result = self.forward(*input, **kwargs)
      File "E:\Work\yolov5_T23\models\Confor_VC.py", line 93, in forward
        x1 = self.pool(x) - x # x1 = self.pool(x) - x
      File "C:\conda\conda\envs\torch17\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
        result = self.forward(*input, **kwargs)
      File "C:\conda\conda\envs\torch17\lib\site-packages\torch\nn\modules\pooling.py", line 594, in forward
        return F.avg_pool2d(input, self.kernel_size, self.stride,
    TypeError: avg_pool2d(): argument 'kernel_size' (position 2) must be tuple of ints, not bool

    I want to put the PoolFormer block after a ConvBlock, and the above problem occurred. Thank you!

    opened by QY1994-0919 4
  • About MLN(Modified Layer Normalization)

    This paper provides new perspectives on the Transformer block, but I have a question about one of the details. As far as I know, the LayerNorm officially provided by PyTorch implements the same function as MLN, which computes the mean and variance along the token and channel dimensions. So where is the improvement?

    image

    The official example:

        # Image Example
        N, C, H, W = 20, 5, 10, 10
        input = torch.randn(N, C, H, W)
        # Normalize over the last three dimensions (i.e. the channel and spatial dimensions)
        # as shown in the image below
        layer_norm = nn.LayerNorm([C, H, W])
        output = layer_norm(input)

    opened by youngtboy 3
  • How to achieve the grad-CAM visualization?

    Thanks for your awesome work and for sharing them all.

    I found the pictures in the supplementary material beautiful, and I would like to produce similar ones.

    Could you share the code for this, or tell me how to obtain the Grad-CAM activation maps?

    opened by DoranLyong 3
  • CVE-2007-4559 Patch

    Patching CVE-2007-4559

    Hi, we are security researchers from the Advanced Research Center at Trellix. We have begun a campaign to patch a widespread bug named CVE-2007-4559. CVE-2007-4559 is a 15-year-old bug in the Python tarfile package. By using extract() or extractall() on a tarfile object without sanitizing input, a maliciously crafted .tar file could perform a directory path traversal attack. We found at least one unsanitized extractall() in your codebase and are providing a patch for you via pull request. The patch essentially checks to see if all tarfile members will be extracted safely and throws an exception otherwise. We encourage you to use this patch or your own solution to secure against CVE-2007-4559. Further technical information about the vulnerability can be found in this blog.

    If you have further questions, you may contact us through this project's lead researcher, Kasimir Schulz.

    opened by TrellixVulnTeam 0
  • On the use of Apex AMP and hybrid stages

    Is there a specific reason why you used Apex AMP instead of the native AMP provided by PyTorch? Have you tried native AMP?

    I tried to train poolformer_s12 and poolformer_s24 with solo-learn; with native fp16 the loss goes to nan after a few epochs, while with fp32 it works fine. Did you experience similar behavior?

    On a side note, can you provide the implementation and the hyperparameters for the hybrid stage [Pool, Pool, Attention, Attention]? It seems very interesting!

    discussion 
    opened by DonkeyShot21 6
  • Can I say PoolFormer is just a non-trainable MLP-like module?

    Hi! Thanks for sharing the great work! I have some questions about PoolFormer. If I explain PoolFormer like the following attachments, can I say PoolFormer is just a non-trainable MLP-like model?

    image image

    discussion 
    opened by 072jiajia 8
  • About subtract in pooling

    Hi, thank you for publishing such a nice paper. I just have one question. I do not understand the subtraction of the input in Eqn. 4. Is it necessary? What will happen if we just do the average pooling without subtracting the input?

    discussion 
    opened by Dong-Huo 16
Owner
Sea AI Lab