A TensorFlow 2.x implementation of Masked Autoencoders Are Scalable Vision Learners


A TensorFlow implementation of Masked Autoencoders Are Scalable Vision Learners [1]. Our implementation of the proposed method is available in the mae-pretraining.ipynb notebook, which also includes evaluation with linear probing and can be fully executed on Google Colab. Our main objective is to present the core idea of the proposed method in a minimal and readable manner. We have also prepared a blog post to help you get started with Masked Autoencoders easily.
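The core idea, hiding a large random subset of image patches and passing only the visible ones to the encoder, can be sketched as follows. This is a NumPy illustration written for this README, not the notebook's code: the function names, the patch dimensions, and the 0.75 masking proportion used in the demo are assumptions chosen for the example.

```python
import numpy as np


def random_mask_indices(batch_size, num_patches, mask_proportion, rng):
    """Return (mask_indices, unmask_indices) with one random split per image."""
    num_mask = int(mask_proportion * num_patches)
    # Sorting uniform noise yields an independent random permutation per row.
    noise = rng.uniform(size=(batch_size, num_patches))
    shuffled = np.argsort(noise, axis=-1)
    return shuffled[:, :num_mask], shuffled[:, num_mask:]


def gather_patches(patches, indices):
    """Batched gather: patches (B, N, D), indices (B, K) -> (B, K, D)."""
    return np.take_along_axis(patches, indices[:, :, None], axis=1)


rng = np.random.default_rng(0)
# Hypothetical input: 2 images, 64 patches each, 6x6x3 pixels per patch.
patches = rng.normal(size=(2, 64, 108))
mask_idx, unmask_idx = random_mask_indices(2, 64, 0.75, rng)

# Only the visible (unmasked) patches would be fed to the encoder.
visible = gather_patches(patches, unmask_idx)
print(visible.shape)  # (2, 16, 108)
```

With a masking proportion of 0.75, the encoder sees only a quarter of the patches, which is what keeps the encoder side of the asymmetric design cheap.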


With just 100 epochs of pre-training and a fairly lightweight, asymmetric autoencoder architecture, we achieve 49.33% accuracy with linear probing on the CIFAR-10 dataset. Our training logs and encoder weights are released in Weights and Logs. For comparison, we took the encoder architecture and trained it from scratch (refer to regular-classification.ipynb) in a fully supervised manner. This gave us ~76% test top-1 accuracy.

We note that with further hyperparameter tuning and more epochs of pre-training, linear-probing performance can likely be improved. Below we present some more results:

| Config | Masking proportion | LP performance | Encoder weights & logs |
|---|---|---|---|
| Encoder & decoder layers: 3 & 1; Batch size: 256 | 0.6 | 44.25% | Link |
| Same as above | 0.75 | 46.84% | Link |
| Encoder & decoder layers: 6 & 2; Batch size: 256 | 0.75 | 48.16% | Link |
| Encoder & decoder layers: 9 & 3; Batch size: 256; Weight decay: 1e-5 | 0.75 | 49.33% | Link |

LP denotes linear-probing. Config is mostly based on what we define in the hyperparameters section of this notebook: mae-pretraining.ipynb.
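For readers unfamiliar with linear probing, the sketch below shows the general recipe: the encoder is kept frozen and only a linear classifier is trained on its output features. Everything here is a toy stand-in and an assumption for illustration (a fixed random projection plays the role of the pre-trained encoder, and the data is synthetic); it is not the notebook's evaluation code and does not reproduce the numbers in the table.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for a pre-trained encoder: a fixed projection + ReLU.
W_enc = rng.normal(size=(32, 16)) / np.sqrt(32)
W_enc_before = W_enc.copy()


def frozen_encoder(x):
    return np.maximum(x @ W_enc, 0.0)


def softmax(z):
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


# Synthetic two-class data.
x = rng.normal(size=(200, 32))
y = (x[:, 0] > 0).astype(int)

# Features are computed once; the encoder is never updated.
features = frozen_encoder(x)
features = (features - features.mean(axis=0)) / (features.std(axis=0) + 1e-8)

num_classes = 2
W = np.zeros((features.shape[1], num_classes))  # the only trainable weights
b = np.zeros(num_classes)
onehot = np.eye(num_classes)[y]


def cross_entropy(W, b):
    p = softmax(features @ W + b)
    return -np.mean(np.log(p[np.arange(len(y)), y] + 1e-12))


initial_loss = cross_entropy(W, b)
for _ in range(200):
    p = softmax(features @ W + b)
    grad_logits = (p - onehot) / len(y)
    # Plain gradient descent on the linear head only.
    W -= 0.1 * features.T @ grad_logits
    b -= 0.1 * grad_logits.sum(axis=0)
final_loss = cross_entropy(W, b)

print(f"loss before: {initial_loss:.3f}, after: {final_loss:.3f}")
```

Because only `W` and `b` change, the resulting accuracy reflects how linearly separable the frozen representations are, which is why it is used as a measure of pre-training quality.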

Acknowledgements

References

[1] Masked Autoencoders Are Scalable Vision Learners; He et al.; arXiv 2021; https://arxiv.org/abs/2111.06377.

Comments
  • Excellent work (`mae.ipynb`)!

    @ariG23498 this is fantastic stuff. Super clean, readable, and coherent with the original implementation. A couple of suggestions that would likely make things even better:

    • Since you have already implemented masking visualization utilities how about making them part of the PatchEncoder itself? That way you could let it accept a test image, apply random masking, and plot it just like the way you are doing in the earlier cells. This way I believe the notebook will be cleaner.
    • AdamW (tfa.optimizers.AdamW) is a better choice when it comes to training Transformer-based models.
    • Are we taking the loss on the correct component? I remember you mentioning it being dealt with differently.

    After these points are addressed I will take a crack at porting the training loop to TPUs along with other performance monitoring callbacks.
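As background on the AdamW suggestion: AdamW differs from Adam with L2 regularization in that the weight decay is applied directly to the parameters instead of being folded into the gradient. The NumPy sketch below shows a single update step under that idea. The function name and default values are assumptions for illustration, and libraries differ in details such as whether the decay term is scaled by the learning rate (it is here), so this should not be read as the exact behavior of any particular optimizer class.

```python
import numpy as np


def adamw_step(param, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=1e-5):
    """One AdamW-style update; t is the 1-based step count."""
    # Standard Adam moment estimates with bias correction.
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad ** 2
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    # Decoupled weight decay: added to the update, not to the gradient.
    update = m_hat / (np.sqrt(v_hat) + eps) + weight_decay * param
    return param - lr * update, m, v


param = np.array([1.0, -2.0])
zeros = np.zeros_like(param)

# With a zero gradient, only the decay term acts on the parameters.
decayed, _, _ = adamw_step(param, zeros, zeros, zeros, t=1, lr=0.1, weight_decay=0.5)
print(decayed)
```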

    opened by sayakpaul 7
  • Unshuffle the patches?

    Your code helps me a lot! However, I still have some questions. In the paper, the authors say they unshuffle the full list before applying the decoder. In the MaskedAutoencoder class of your implementation, decoder_inputs = tf.concat([encoder_outputs, masked_embeddings], axis=1) is used, and no unshuffling is applied. Could you explain the reasoning behind this? Thanks a lot!

    opened by changtaoli 2
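For reference, the unshuffle step that the question refers to can be sketched in NumPy as below. This illustrates the general technique only (restoring the original patch order with the inverse permutation); all names are hypothetical, the encoder is replaced by an identity stand-in, and it makes no claim about how or why the notebook orders its decoder inputs.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, num_patches, dim, num_keep = 2, 8, 4, 2

# Toy tokens: every feature of token n holds the value n, so the original
# position of each token is easy to read off.
positions = np.arange(num_patches, dtype=float)[None, :, None]
tokens = np.broadcast_to(positions, (batch, num_patches, dim))

# Random permutation per image; the first num_keep entries stay visible.
ids_shuffle = np.argsort(rng.uniform(size=(batch, num_patches)), axis=1)
ids_keep = ids_shuffle[:, :num_keep]
visible = np.take_along_axis(tokens, ids_keep[:, :, None], axis=1)

# Identity stand-in for the encoder, and -1 as a stand-in mask token.
encoder_outputs = visible
mask_tokens = np.full((batch, num_patches - num_keep, dim), -1.0)

# Concatenation leaves the sequence in shuffled order ...
decoder_inputs = np.concatenate([encoder_outputs, mask_tokens], axis=1)

# ... and the inverse permutation puts every token back at its position.
ids_restore = np.argsort(ids_shuffle, axis=1)
unshuffled = np.take_along_axis(decoder_inputs, ids_restore[:, :, None], axis=1)

print(unshuffled[0, :, 0])  # position value where visible, -1 where masked
```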
  • Could you also share the weight of the pretrained decoder?

    Hi,

    Thanks for your excellent implementation! I found that you have shared the weights of the encoder, but if we want to replicate the reconstruction, the pretrained decoder is still needed. So, could you also share the weight of the pretrained decoder?

    Best Regards, Hongxin

    opened by hongxin001 1
  • Issue with the plotting utility `show_masked_image`

    Should be:

    # Assumes the notebook's imports (numpy as np, matplotlib.pyplot as plt,
    # tensorflow as tf) and its constants NUM_PATCHES and PATCH_SIZE.
    def show_masked_image(self, patches):
        # Utility function that helps visualize masked images.
        _, unmask_indices = self.get_random_indices()
        unmasked_patches = tf.gather(patches, unmask_indices, axis=1, batch_dims=1)

        # Sort the indices, and reorder the patches to match, so that the
        # patches can be walked in grid order while plotting.
        ids = tf.argsort(unmask_indices)
        sorted_unmask_indices = tf.sort(unmask_indices)
        unmasked_patches = tf.gather(unmasked_patches, ids, batch_dims=1)

        # Select a random image from the batch for visualization.
        idx = np.random.choice(len(sorted_unmask_indices))
        print(f"Index selected: {idx}.")

        n = int(np.sqrt(NUM_PATCHES))
        unmask_index = sorted_unmask_indices[idx]
        unmasked_patch = unmasked_patches[idx]

        plt.figure(figsize=(4, 4))

        count = 0
        for i in range(NUM_PATCHES):
            ax = plt.subplot(n, n, i + 1)

            if count < unmask_index.shape[0] and unmask_index[count].numpy() == i:
                # Visible patch: draw its pixels.
                patch = unmasked_patch[count]
                patch_img = tf.reshape(patch, (PATCH_SIZE, PATCH_SIZE, 3))
                count = count + 1
            else:
                # Masked patch: draw a black square.
                patch_img = tf.zeros((PATCH_SIZE, PATCH_SIZE, 3))
            plt.imshow(patch_img)
            plt.axis("off")
        plt.show()

        # Return the random index to validate the image outside the method.
        return idx
    
    opened by ariG23498 1
Releases(v1.0.0)
Owner
Aritra Roy Gosthipaty