Vision Transformer and MLP-Mixer Architectures

Update (2.7.2021): Added the "When Vision Transformers Outperform ResNets..." paper, and SAM (Sharpness-Aware Minimization) optimized ViT and MLP-Mixer checkpoints.

Update (20.6.2021): Added the "How to train your ViT? ..." paper, and a new Colab to explore the >50k pre-trained and fine-tuned checkpoints mentioned in the paper.

Update (18.6.2021): This repository was rewritten to use Flax Linen API and ml_collections.ConfigDict for configuration.

In this repository we release models from the papers cited at the end of this README.

The models were pre-trained on the ImageNet and ImageNet-21k datasets. We provide the code for fine-tuning the released models in JAX/Flax.



The Colabs below run with both GPUs and TPUs (8 cores, data parallelism).

The first Colab demonstrates the JAX code of Vision Transformers and MLP Mixers. This Colab allows you to edit the files from the repository directly in the Colab UI, has annotated cells that walk you through the code step by step, and lets you interact with the data.

The second Colab allows you to explore the >50k Vision Transformer and hybrid checkpoints that were used to generate the data of the third paper "How to train your ViT? ...". The Colab includes code to explore and select checkpoints, and to do inference both using the JAX code from this repo, and also using the popular timm PyTorch library that can directly load these checkpoints as well.
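
For instance, once you know which checkpoint you want, loading a ViT in PyTorch is a one-liner with timm. This is a minimal sketch, assuming timm and torch are installed; the model name below is just an illustrative ViT variant shipped with timm, not a specific checkpoint selected in the Colab:

import timm
import torch

# Illustrative example: instantiate a ViT-B/16 with pre-trained weights via timm.
model = timm.create_model('vit_base_patch16_224', pretrained=True)
model.eval()

# Run a dummy batch through the model to check the output shape.
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)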

The second Colab also lets you fine-tune the checkpoints on any tfds dataset and your own dataset with examples in individual JPEG files (optionally directly reading from Google Drive).

Note: As of now (6/20/21) Google Colab only supports a single GPU (Nvidia Tesla T4), and TPUs (currently TPUv2-8) are attached indirectly to the Colab VM and communicate over a slow network, which leads to poor training speed. You would usually want to set up a dedicated machine if you have a non-trivial amount of data to fine-tune on. For details see the Running on cloud section.


Make sure you have Python>=3.6 installed on your machine.

For installing JAX, follow the instructions provided in the corresponding repository linked here. Note that installation instructions for GPU differ slightly from the instructions for CPU.

Then, install python dependencies by running:

pip install -r vit_jax/requirements.txt

For more details refer to the section Running on cloud below.

Fine-tuning a model

You can run fine-tuning of the downloaded model on your dataset of interest. All models share the same command line interface.

For example, for fine-tuning a ViT-B/16 (pre-trained on imagenet21k) on CIFAR10 (note how we specify b16,cifar10 as arguments to the config, and how we instruct the code to access the models directly from a GCS bucket instead of first downloading them into the local directory):

python -m vit_jax.main --workdir=/tmp/vit-$(date +%s) \
    --config=$(pwd)/vit_jax/configs/,cifar10 \

In order to fine-tune a Mixer-B/16 (pre-trained on imagenet21k) on CIFAR10:

python -m vit_jax.main --workdir=/tmp/vit-$(date +%s) \
    --config=$(pwd)/vit_jax/configs/ \

The "How to train your ViT? ..." paper added >50k checkpoints that you can fine-tune with the configs/ config. When you only specify the model name (the value from configs/, then the best i21k checkpoint by upstream validation accuracy ("recommended" checkpoint, see section 4.5 of the paper) is chosen. To make up your mind which model you want to use, have a look at Figure 3 in the paper. It's also possible to choose a different checkpoint (see Colab vit_jax_augreg.ipynb) and then specify the value from the filename or adapt_filename column, which correspond to the filenames without .npz from the gs://vit_models/augreg directory.

python -m vit_jax.main --workdir=/tmp/vit-$(date +%s) \
    --config=$(pwd)/vit_jax/configs/ \
    --config.dataset=oxford_iiit_pet \

Currently, the code will automatically download the CIFAR-10 and CIFAR-100 datasets. Other public or custom datasets can be easily integrated using the tensorflow_datasets library. Note that you will also need to update vit_jax/ to specify some parameters about any added dataset.
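
As a rough sketch of what the tensorflow_datasets side of such an integration looks like (illustrative only; the resolutions, splits, and preprocessing actually used are defined by the repository's input pipeline):

import tensorflow_datasets as tfds

# Load CIFAR-10; any other tfds dataset can be loaded the same way.
ds, info = tfds.load('cifar10', split='train', with_info=True)
print(info.features['label'].num_classes)  # 10
for example in ds.take(1):
    print(example['image'].shape)          # (32, 32, 3)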

Note that our code uses all available GPUs/TPUs for fine-tuning.

To see a detailed list of all available flags, run python3 -m vit_jax.main --help.

Notes on memory:

  • Different models require different amounts of memory. Available memory also depends on the accelerator configuration (both type and count). If you encounter an out-of-memory error, you can increase the value of --config.accum_steps=8 -- alternatively, you could also decrease --config.batch=512 (and decrease --config.base_lr accordingly; see the sketch after this list).
  • The host keeps a shuffle buffer in memory. If you encounter a host OOM (as opposed to an accelerator OOM), you can decrease the default --config.shuffle_buffer=50000.
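
The toy snippet below (the names are illustrative and not part of vit_jax) spells out the relationship between these flags: gradient accumulation keeps the effective batch size fixed while shrinking the per-step memory footprint, whereas reducing the batch size itself calls for a roughly linear learning-rate adjustment.

def adjust_for_memory(batch, base_lr, accum_steps, reference_batch=512):
    # With gradient accumulation, each optimizer step still uses `batch` examples,
    # processed as `accum_steps` smaller micro-batches.
    micro_batch = batch // accum_steps
    # If the total batch size is reduced instead, scale the learning rate linearly.
    scaled_lr = base_lr * batch / reference_batch
    return micro_batch, scaled_lr

print(adjust_for_memory(batch=512, base_lr=0.03, accum_steps=8))  # (64, 0.03)
print(adjust_for_memory(batch=256, base_lr=0.03, accum_steps=8))  # (32, 0.015)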

Vision Transformer

by Alexey Dosovitskiy*†, Lucas Beyer*, Alexander Kolesnikov*, Dirk Weissenborn*, Xiaohua Zhai*, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit and Neil Houlsby*†.

(*) equal technical contribution, (†) equal advising.

Figure 1 from paper

Overview of the model: we split an image into fixed-size patches, linearly embed each of them, add position embeddings, and feed the resulting sequence of vectors to a standard Transformer encoder. In order to perform classification, we use the standard approach of adding an extra learnable "classification token" to the sequence.
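
A minimal Flax sketch of the input processing just described (illustrative only; module and parameter names do not match the repository's actual implementation):

import jax.numpy as jnp
import flax.linen as nn

class PatchEmbedding(nn.Module):
    patch_size: int = 16
    hidden_dim: int = 768

    @nn.compact
    def __call__(self, images):  # images: [batch, height, width, 3]
        b = images.shape[0]
        # Split the image into fixed-size patches and linearly embed each of them
        # (a strided convolution does both in one step).
        x = nn.Conv(self.hidden_dim,
                    kernel_size=(self.patch_size, self.patch_size),
                    strides=(self.patch_size, self.patch_size))(images)
        x = x.reshape(b, -1, self.hidden_dim)  # [batch, num_patches, hidden_dim]
        # Prepend the learnable classification token.
        cls = self.param('cls', nn.initializers.zeros, (1, 1, self.hidden_dim))
        x = jnp.concatenate([jnp.tile(cls, (b, 1, 1)), x], axis=1)
        # Add learnable position embeddings.
        pos = self.param('pos_embedding', nn.initializers.normal(stddev=0.02),
                         (1, x.shape[1], self.hidden_dim))
        return x + pos  # sequence of vectors fed to a standard Transformer encoder

For a 224x224 input with 16x16 patches this yields a sequence of 197 vectors (196 patch embeddings plus the class token), which is then processed by the Transformer encoder.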

Available ViT models

We provide models pre-trained on ImageNet-21k for the following architectures: ViT-B/16, ViT-B/32, ViT-L/16 and ViT-L/32. We provide the same models pre-trained on ImageNet-21k and fine-tuned on ImageNet.

Update (29.7.2021): Added ViT-B/8 AugReg models (3 upstream checkpoints and adaptations with resolution=224).

Update (2.7.2021): We added the ViT models trained from scratch with SAM optimizer on ImageNet (with basic Inception-style preprocessing). The resultant ViTs outperform ResNets of similar size and throughput without large-scale pre-training or strong data augmentations. They also possess more perceptive attention maps. To use those models, you can simply replace the model path in vit_jax.ipynb with gs://vit_models/sam.

Update (19.5.2021): With publication of the "How to train your ViT? ..." paper, we added more than 50k ViT and hybrid models pre-trained on ImageNet and ImageNet-21k with various degrees of data augmentation and model regularization, and fine-tuned on ImageNet, Pets37, Kitti-distance, CIFAR-100, and Resisc45. Check out vit_jax_augreg.ipynb to navigate this treasure trove of models! For example, you can use that Colab to fetch the filenames of recommended pre-trained and fine-tuned checkpoints from the i21k_300 column of Table 3 in the paper:

Model Pre-trained checkpoint Size Fine-tuned checkpoint Resolution Img/sec Imagenet accuracy
L/16 gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_strong1-wd_0.1-do_0.0-sd_0.0.npz 1243 MiB gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_strong1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz 384 50 85.59%
B/16 gs://vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz 391 MiB gs://vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz 384 138 85.49%
S/16 gs://vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz 115 MiB gs://vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz 384 300 83.73%
R50+L/32 gs://vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz 1337 MiB gs://vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz 384 327 85.99%
R26+S/32 gs://vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0.npz 170 MiB gs://vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz 384 560 83.85%
Ti/16 gs://vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz 37 MiB gs://vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz 384 610 78.22%
B/32 gs://vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0.npz 398 MiB gs://vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz 384 955 83.59%
S/32 gs://vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_none-wd_0.1-do_0.0-sd_0.0.npz 118 MiB gs://vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_none-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz 384 2154 79.58%
R+Ti/16 gs://vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz 40 MiB gs://vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz 384 2426 75.40%

Update (1.12.2020): We have added the R50+ViT-B/16 hybrid model (ViT-B/16 on top of a ResNet-50 backbone). When pretrained on imagenet21k, this model achieves almost the performance of the L/16 model with less than half the computational fine-tuning cost. Note that "R50" is somewhat modified for the B/16 variant: the original ResNet-50 has [3,4,6,3] blocks, each reducing the resolution of the image by a factor of two. In combination with the ResNet stem this would result in a reduction of 32x, so even with a patch size of (1,1) the ViT-B/16 variant cannot be realized anymore. For this reason we instead use [3,4,9] blocks for the R50+B/16 variant.

Update (9.11.2020): We have also added the ViT-L/16 model.

Update (29.10.2020): We have added ViT-B/16 and ViT-L/16 models pretrained on ImageNet-21k and then fine-tuned on ImageNet at 224x224 resolution (instead of default 384x384). These models have the suffix "-224" in their name. They are expected to achieve 81.2% and 82.7% top-1 accuracies respectively.

You can find all these models in the following storage bucket:

For example, if you would like to download the ViT-B/16 pre-trained on imagenet21k, run the following command:


Expected ViT results

The table below shows the results of experiments run both with transformer.dropout_rate=0.1 (as in the ViT paper) and with transformer.dropout_rate=0.0, which somewhat improves results for the B/16, B/32, and L/32 models. The better setting was chosen for the default config of the models in this repository. Note also that all these models have representation_size=None, i.e. the last layer before the classification layer is dropped for fine-tuning.

model dataset dropout=0.0 dropout=0.1
R50+ViT-B_16 cifar10 98.72%, 3.9h (A100) 98.94%, 10.1h (V100)
R50+ViT-B_16 cifar100 90.88%, 4.1h (A100) 92.30%, 10.1h (V100)
R50+ViT-B_16 imagenet2012 83.72%, 9.9h (A100) 85.08%, 24.2h (V100)
ViT-B_16 cifar10 99.02%, 2.2h (A100) 98.76%, 7.8h (V100)
ViT-B_16 cifar100 92.06%, 2.2h (A100) 91.92%, 7.8h (V100)
ViT-B_16 imagenet2012 84.53%, 6.5h (A100) 84.12%, 19.3h (V100)
ViT-B_32 cifar10 98.88%, 0.8h (A100) 98.75%, 1.8h (V100)
ViT-B_32 cifar100 92.31%, 0.8h (A100) 92.05%, 1.8h (V100)
ViT-B_32 imagenet2012 81.66%, 3.3h (A100) 81.31%, 4.9h (V100)
ViT-L_16 cifar10 99.13%, 6.9h (A100) 99.14%, 24.7h (V100)
ViT-L_16 cifar100 92.91%, 7.1h (A100) 93.22%, 24.4h (V100)
ViT-L_16 imagenet2012 84.47%, 16.8h (A100) 85.05%, 59.7h (V100)
ViT-L_32 cifar10 99.06%, 1.9h (A100) 99.09%, 6.1h (V100)
ViT-L_32 cifar100 93.29%, 1.9h (A100) 93.34%, 6.2h (V100)
ViT-L_32 imagenet2012 81.89%, 7.5h (A100) 81.13%, 15.0h (V100)

We would also like to emphasize that high-quality results can be achieved with shorter training schedules, and we encourage users of our code to play with hyper-parameters to trade off accuracy and computational budget. Some examples for the CIFAR-10/100 datasets are presented in the table below.

upstream model dataset total_steps / warmup_steps accuracy wall-clock time
imagenet21k ViT-B_16 cifar10 500 / 50 98.59% 17m
imagenet21k ViT-B_16 cifar10 1000 / 100 98.86% 39m
imagenet21k ViT-B_16 cifar100 500 / 50 89.17% 17m
imagenet21k ViT-B_16 cifar100 1000 / 100 91.15% 39m


MLP-Mixer

by Ilya Tolstikhin*, Neil Houlsby*, Alexander Kolesnikov*, Lucas Beyer*, Xiaohua Zhai, Thomas Unterthiner, Jessica Yung, Andreas Steiner, Daniel Keysers, Jakob Uszkoreit, Mario Lucic, Alexey Dosovitskiy.

(*) equal contribution.

Figure 1 from paper

MLP-Mixer (Mixer for short) consists of per-patch linear embeddings, Mixer layers, and a classifier head. Mixer layers contain one token-mixing MLP and one channel-mixing MLP, each consisting of two fully-connected layers and a GELU nonlinearity. Other components include: skip-connections, dropout, and linear classifier head.
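
A minimal Flax sketch of one such Mixer layer (illustrative only; the repository's actual modules may differ in naming and details):

import jax.numpy as jnp
import flax.linen as nn

class MlpBlock(nn.Module):
    mlp_dim: int

    @nn.compact
    def __call__(self, x):
        # Two fully-connected layers with a GELU nonlinearity in between.
        y = nn.Dense(self.mlp_dim)(x)
        y = nn.gelu(y)
        return nn.Dense(x.shape[-1])(y)

class MixerBlock(nn.Module):
    tokens_mlp_dim: int
    channels_mlp_dim: int

    @nn.compact
    def __call__(self, x):  # x: [batch, tokens, channels]
        # Token-mixing MLP: mixes information across patches (tokens).
        y = nn.LayerNorm()(x)
        y = jnp.swapaxes(y, 1, 2)
        y = MlpBlock(self.tokens_mlp_dim)(y)
        y = jnp.swapaxes(y, 1, 2)
        x = x + y  # skip-connection
        # Channel-mixing MLP: mixes information across channels within each patch.
        y = nn.LayerNorm()(x)
        return x + MlpBlock(self.channels_mlp_dim)(y)

In the full model, per-patch linear embeddings precede a stack of such blocks, followed by a final layer norm, global average pooling, and the linear classifier head.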

For installation follow the same steps as above.

Available Mixer models

Update (2.7.2021): We added the MLP-Mixer models trained with SAM on ImageNet without strong augmentations (gs://mixer_models/sam). The loss landscapes become much smoother, and we found that the activated neurons for the first few layers decrease dramatically after SAM, indicating the potential redundancy of image patches.

We provide the Mixer-B/16 and Mixer-L/16 models pre-trained on the ImageNet and ImageNet-21k datasets. Details can be found in Table 3 of the Mixer paper. All the models can be found at:

Expected Mixer results

We ran the fine-tuning code on a Google Cloud machine with four V100 GPUs, using the default adaptation parameters from this repository. Here are the results:

upstream model dataset accuracy wall_clock_time
ImageNet Mixer-B/16 cifar10 96.72% 3.0h
ImageNet Mixer-L/16 cifar10 96.59% 3.0h
ImageNet-21k Mixer-B/16 cifar10 96.82% 9.6h
ImageNet-21k Mixer-L/16 cifar10 98.34% 10.0h

Running on cloud

While the Colabs above are useful to get started, you would usually want to train on a larger machine with more powerful accelerators.

Create a VM

You can use the following commands to setup a VM with GPUs on Google Cloud:

# Set variables used by all commands below.
# Note that project must have accounting set up.
# For a list of zones with GPUs refer to
PROJECT=my-awesome-gcp-project  # Project must have billing enabled.

# Below settings have been tested with this repository. You can choose other
# combinations of images & machines; for the options, refer to the corresponding gcloud commands:
# gcloud compute images list --project ml-images
# gcloud compute machine-types list
# etc.
gcloud compute instances create $VM_NAME \
    --project=$PROJECT --zone=$ZONE \
    --image=c1-deeplearning-tf-2-5-cu110-v20210527-debian-10 \
    --image-project=ml-images --machine-type=n1-standard-96 \
    --scopes=cloud-platform,storage-full --boot-disk-size=256GB \
    --boot-disk-type=pd-ssd --metadata=install-nvidia-driver=True \
    --maintenance-policy=TERMINATE \

# Connect to VM (after some minutes needed to setup & start the machine).
gcloud compute ssh --project $PROJECT --zone $ZONE $VM_NAME

# Stop the VM after use (only storage is billed for a stopped VM).
gcloud compute instances stop --project $PROJECT --zone $ZONE $VM_NAME

# Delete VM after use (this will also remove all data stored on VM).
gcloud compute instances delete --project $PROJECT --zone $ZONE $VM_NAME

Alternatively, you can use the following similar commands to set up a Cloud VM with TPUs attached (the commands below are copied from the TPU tutorial):

PROJECT=my-awesome-gcp-project  # Project must have billing enabled.

# Required to set up service identity initially.
gcloud beta services identity create --service

# Create a VM with TPUs directly attached to it.
gcloud alpha compute tpus tpu-vm create $VM_NAME \
    --project=$PROJECT --zone=$ZONE \
    --accelerator-type v3-8 \
    --version v2-alpha

# Connect to VM (after some minutes needed to setup & start the machine).
gcloud alpha compute tpus tpu-vm ssh --project $PROJECT --zone $ZONE $VM_NAME

# Stop the VM after use (only storage is billed for a stopped VM).
gcloud alpha compute tpus tpu-vm stop --project $PROJECT --zone $ZONE $VM_NAME

# Delete VM after use (this will also remove all data stored on VM).
gcloud alpha compute tpus tpu-vm delete --project $PROJECT --zone $ZONE $VM_NAME

Setup VM

Then fetch the repository and install the dependencies (including jaxlib with TPU support) as usual:

git clone --depth=1 --branch=master
cd vision_transformer
pip3 install virtualenv
python3 -m virtualenv env
. env/bin/activate

If you're connected to a VM with GPUs attached, install JAX with the following command:

pip3 install --upgrade jax jaxlib \

If you're connected to a VM with TPUs attached, install JAX with the following command:

pip3 install --upgrade jax jaxlib

For both GPUs and TPUs, proceed to install the remaining dependencies and check that the accelerators show up in JAX:

pip install -r vit_jax/requirements.txt
# Check that JAX can connect to attached accelerators:
python -c 'import jax; print(jax.devices())'

Finally, execute one of the commands mentioned in the section Fine-tuning a model.


Bibtex

@article{dosovitskiy2020vit,
  title={An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
  author={Dosovitskiy, Alexey and Beyer, Lucas and Kolesnikov, Alexander and Weissenborn, Dirk and Zhai, Xiaohua and Unterthiner, Thomas and Dehghani, Mostafa and Minderer, Matthias and Heigold, Georg and Gelly, Sylvain and Uszkoreit, Jakob and Houlsby, Neil},
  journal={arXiv preprint arXiv:2010.11929},
  year={2020}
}

@article{tolstikhin2021mixer,
  title={MLP-Mixer: An all-MLP Architecture for Vision},
  author={Tolstikhin, Ilya and Houlsby, Neil and Kolesnikov, Alexander and Beyer, Lucas and Zhai, Xiaohua and Unterthiner, Thomas and Yung, Jessica and Steiner, Andreas and Keysers, Daniel and Uszkoreit, Jakob and Lucic, Mario and Dosovitskiy, Alexey},
  journal={arXiv preprint arXiv:2105.01601},
  year={2021}
}

@article{steiner2021augreg,
  title={How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers},
  author={Steiner, Andreas and Kolesnikov, Alexander and Zhai, Xiaohua and Wightman, Ross and Uszkoreit, Jakob and Beyer, Lucas},
  journal={arXiv preprint arXiv:2106.10270},
  year={2021}
}

@article{chen2021sam,
  title={When Vision Transformers Outperform ResNets without Pretraining or Strong Data Augmentations},
  author={Chen, Xiangning and Hsieh, Cho-Jui and Gong, Boqing},
  journal={arXiv preprint arXiv:2106.01548},
  year={2021}
}


Open source release prepared by Andreas Steiner.

Note: This repository was forked and modified from google-research/big_transfer.

This is not an official Google product.

  • Non-finetuned ViT-B_16 has final layer weights set to 0


    I am trying to use the non-finetuned ViT-B_16 and it seems the final layer weights are all set to 0:

    import numpy as np
    params = np.load('imagenet21k_ViT-B_16.npz')
    keys, values = zip(*list(params.items()))
    >>> values[keys.index('head/kernel')]
    array([[0., 0., 0., ..., 0., 0., 0.],
           [0., 0., 0., ..., 0., 0., 0.],
           [0., 0., 0., ..., 0., 0., 0.],
           [0., 0., 0., ..., 0., 0., 0.],
           [0., 0., 0., ..., 0., 0., 0.],
           [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)
    >>> values[keys.index('head/bias')]
    array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)

    The other non-finetuned model weights do not seem to have this problem. Could this checkpoint please be updated?

    Thank you, Rohan

    opened by rtaori 14
  • Hybrid pretrained models

    Thanks for providing the code and pretrained models! I was curious if Hybrid models would also be released; if so, any idea on the timeline? (days? weeks?)

    opened by skrish13 12
  • Hyper Params & loss function for ViT-L/16

    Hello! I was trying to pre-train the ViT-L/16 architecture on imagenet-21k, though there are a few aspects I'm unsure of and hope you can clarify.

    1. What were the hyperparameters used for pre-training? In the paper it's mentioned that the model was trained for either 30/90 epochs, with 0.1 weight decay at the start and 0.03 in the appendix, hence the uncertainty.

    2. Is the data augmentation and model initialization the same as what's used in the released jax code?

    3. What loss function was used? considering that imagenet-21k and JFT-300M are multilabel datasets.


    opened by owmohamm 10
  • Has anyone reproduced the fine-tune results?

    Thanks to the authors for this amazing work.

    I'm currently reproducing the fine-tune part (pre-train on ImageNet 21k and fine-tune on ImageNet 1k). However, the best result I can get is 84.1% with ViT-L16.

    I noticed the authors kindly provided tensorboard for our reference. It shows ViT-L16 achieves 83.25% as early as the 800th step, while I can only get around 80% in step 2503 (we are running on epoch-based, and there should be 2503 steps per epoch for a batch size of 512). In addition, it took 2 hours for the authors to reach step 800, while in our experiments, it takes around 11 minutes.

    Our setup is 64 V100 32G, the total batch size is 512, no accum step is applied. The rest are all following table B.1.1 of the original paper except that we train for 8 epochs which makes the total training steps 20.02k

    opened by ZhiyuanChen 9
  • Hyper-parameters of ViT-B/16 training from scratch

    Thanks for sharing your code. Can you provide the hyper-parameters (e.g. learning rate, weight decay, optimizer type, training epochs) of ViT-B/16 training from scratch on ImageNet dataset? Many thanks.

    opened by liuyuyuil 9
  • Can I download your SAM ViTs checkpoints?

    Thank you for your exciting and promising work! I notice that you have recently released the SAM pre-trained ViTs. I'm also very interested in SAM's effect on vision transformers, but I don't have enough GPUs or TPUs. I'm wondering if and how I can download them and use them in my own PyTorch code? Thank you for your time and contribution. You can email me via [email protected]

    opened by Longday0923 7
  • Pretrained weights and position embedding at 224x224

    I've been fiddling with the models a bit, massaging the weights into my PyTorch impl. I have the base 384x384 models working well, but when generating the params for the 224x224 models, the resulting output has low validation accuracy (it was in the 76s top-1 when I killed it).

    Is that expected? Is the pos embedding interpolation from 24x24 grid only intended as a starting point for transfer learning and not expected to provide good results as is? By comparison the base 384x384 16x16 patch is validating at 84.2 and the 384x384 32x32 patch at 81.7

    opened by rwightman 7
  • About pretrain head of vit

    In the paper, it says:

    The classification head is implemented by a MLP with one hidden layer at pre-training time and by a single linear layer at fine-tuning time.

    In the code, it seems only the finetuning code shows up.

    What are the parameters (hidden size, etc.) of the head used to pre-train ViT?

    opened by cissoidx 4
  • Difference between the two models on the results?

    What is the difference between the first row and the second row of the table in the expected results? I expect the resolution to be 224x224 or 384x384. Is that correct?

    opened by jeonsworld 4
  • Inconsistent shapes between value and initializer for parameter "scale" in "/gn_root" in Vision Transformer AugReg

    I'm trying to run the fine-tuning model in the second Colab, but got the following error when applying parameters to the model.

    UnfilteredStackTrace                      Traceback (most recent call last)
    <ipython-input-47-9ae0da31d080> in <module>()
          1 # Inferance on batch with single example.
    ----> 2 logits, = model.apply({'params': params}, [pp(d['image'], resolution)], train=False)
    18 frames
    UnfilteredStackTrace: flax.errors.ScopeParamShapeError: Inconsistent shapes between value and initializer for parameter "scale" in "/gn_root": (1, 1, 1, 64), (64,). (
    The stack trace below excludes JAX-internal frames.
    The preceding is the original exception that occurred, unmodified.
    The above exception was the direct cause of the following exception:
    ScopeParamShapeError                      Traceback (most recent call last)
    /usr/local/lib/python3.7/dist-packages/flax/linen/ in _normalize(mdl, x, mean, var, reduction_axes, feature_axes, dtype, epsilon, use_bias, use_scale, bias_init, scale_init)
         94   mul = lax.rsqrt(var + epsilon)
         95   if use_scale:
    ---> 96     scale = mdl.param('scale', scale_init, reduced_feature_shape).reshape(feature_shape)
         97     mul *= scale
         98   y *= mul
    ScopeParamShapeError: Inconsistent shapes between value and initializer for parameter "scale" in "/gn_root": (1, 1, 1, 64), (64,). (
    opened by topsy404 3
  • MLPMixer number of parameters

    I re-implemented MLPMixer in TensorFlow 2.4 and got different parameter counts than those reported in the paper:

    S/32: 18.0M S/16: 18.0M B/32: 59.5M B/16: 59.1M L/32: 205.9M L/16: 207.2M H/14: 431.1M

    import tensorflow as tf
    from tensorflow.keras import layers
    from tensorflow.keras.backend import int_shape
    def MLP(x, dim):
        y = layers.Dense(dim)(x)
        y = tf.nn.gelu(y)
        return layers.Dense(int_shape(x)[-1])(y)
    def MixerLayer(x, tokens_mlp_dim, channels_mlp_dim):
        y = layers.LayerNormalization()(x)
        y = tf.transpose(y, perm=[0, 2, 1])
        y = MLP(y, tokens_mlp_dim)
        y = tf.transpose(y, perm=[0, 2, 1])
        x = x + y
        y = layers.LayerNormalization()(x)
        return x + MLP(y, channels_mlp_dim)
    def MLPMixer(num_classes=1000, image_size=224, patch_size=32, num_blocks=8,
                 hidden_dim=512, tokens_mlp_dim=256, channels_mlp_dim=2048):
        # Signature reconstructed; the defaults are assumed to match Mixer-S/32
        # from the paper, so that MLPMixer() below builds the S/32 variant.
        assert image_size % patch_size == 0
        input = layers.Input(shape=(image_size, image_size, 3))
        x = layers.Conv2D(hidden_dim, patch_size, patch_size)(input)
        x = layers.Reshape(target_shape=((image_size // patch_size)**2, hidden_dim))(x)
        for _ in range(num_blocks):
            x = MixerLayer(x, tokens_mlp_dim, channels_mlp_dim)
        x = layers.LayerNormalization()(x)
        x = layers.GlobalAvgPool1D()(x)
        # x = layers.Dense(num_classes, kernel_initializer=tf.keras.initializers.zeros)(x)
        model = tf.keras.models.Model(inputs=input, outputs=x)
        return model
    if __name__ == '__main__':
       S32 = MLPMixer()
       S16 = MLPMixer(patch_size=16)
       B32 = MLPMixer(num_blocks=12, patch_size=32, hidden_dim=768, tokens_mlp_dim=384, channels_mlp_dim=3072)
       B16 = MLPMixer(num_blocks=12, patch_size=16, hidden_dim=768, tokens_mlp_dim=384, channels_mlp_dim=3072)
       L32 = MLPMixer(num_blocks=24, patch_size=32, hidden_dim=1024, tokens_mlp_dim=512, channels_mlp_dim=4096)
       L16 = MLPMixer(num_blocks=24, patch_size=16, hidden_dim=1024, tokens_mlp_dim=512, channels_mlp_dim=4096)
       H14 = MLPMixer(num_blocks=32, patch_size=14, hidden_dim=1280, tokens_mlp_dim=640, channels_mlp_dim=5120)
       print('S/32: {:.1f}M'.format(S32.count_params() / 1e6))
       print('S/16: {:.1f}M'.format(S16.count_params() / 1e6))
       print('B/32: {:.1f}M'.format(B32.count_params() / 1e6))
       print('B/16: {:.1f}M'.format(B16.count_params() / 1e6))
       print('L/32: {:.1f}M'.format(L32.count_params() / 1e6))
       print('L/16: {:.1f}M'.format(L16.count_params() / 1e6))
       print('H/14: {:.1f}M'.format(H14.count_params() / 1e6))
    opened by wmcnally 3
  • GSAM checkpoints don't appear to be valid .npz files

    I can't load the .npz GSAM checkpoints in numpy.

    Also, if/when you get a chance to fix this, would it be possible to name the files differently from the SAM npz files, for the case where checkpoints are cached in a common folder? I.e., ViT-B_16-GSAM.npz would be preferred over both SAM and GSAM checkpoints being named ViT-B_16.npz.


    opened by rwightman 2
  • Github page for the project is down

    It seems that the GitHub Pages website of this project is down. When clicking on the following link, we get a 404 error.

    opened by peterbonnesoeur 0
  • Convert fine tuned checkpoint on custom dataset to similar pretrained .npz file

    I was wondering if it was possible to convert a checkpoint fine-tuned on a custom dataset to a .npz file similar to the pretrained checkpoints?

    Or is there another way to save the model with the fine-tuned checkpoint, and to load that fine-tuned model for inference? I have tried using flax_checkpoints.restore_checkpoint() to load the fine-tuned checkpoint, but I'm not sure how to retrieve the parameter dictionary that is used by the model.apply function, similar to the inference examples in the Google Colabs.

    Thank you very much in advance!

    opened by jwhg 0
  • Image segmentation or image classification?

    I just got a chance to go over the details of this repository. Based on the linked notebook, my understanding is that its purpose is image classification rather than image segmentation. I am not sure whether I understood the link's contents correctly, but my impression is that this repository was not designed for image segmentation. Can anyone clarify this?

    opened by WEIYI2021 0
  • Google Colab notebook: "Vision Transformer AugReg" imports not working (it used to work a month ago)

    Hello, I found a bug that I cannot fix by myself for some time now, so I'm opening this issue.

    Reproduction instructions: When using this google colab notebook from readme:


    I'm not changing anything, just executing the code, and when importing some libraries I'm getting the following errors (I remember that around a month ago this was working for me):


    When I install the "tensorflow_text" library, a new error is shown:


    Can this be fixed so the notebook would be executable? Thank you

    opened by anneyxa 1