A TensorFlow 2.x implementation of Masked Autoencoders Are Scalable Vision Learners

Overview

Masked Autoencoders Are Scalable Vision Learners

Open In Colab

A TensorFlow implementation of Masked Autoencoders Are Scalable Vision Learners [1]. Our implementation of the proposed method is available in mae-pretraining.ipynb notebook. It includes evaluation with linear probing as well. Furthermore, the notebook can be fully executed on Google Colab. Our main objective is to present the core idea of the proposed method in a minimal and readable manner. We have also prepared a blog for getting started with Masked Autoencoder easily.


With just 100 epochs of pre-training and a fairly lightweight and asymmetric Autoencoder architecture we achieve 49.33%% accuracy with linear probing on the CIFAR-10 dataset. Our training logs and encoder weights are released in Weights and Logs. For comparison, we took the encoder architecture and trained it from scratch (refer to regular-classification.ipynb) in a fully supervised manner. This gave us ~76% test top-1 accuracy.

We note that with further hyperparameter tuning and more epochs of pre-training, we can achieve a better performance with linear-probing. Below we present some more results:

Config Masking
proportion
LP
performance
Encoder weights
& logs
Encoder & decoder layers: 3 & 1
Batch size: 256
0.6 44.25% Link
Do 0.75 46.84% Link
Encoder & decoder layers: 6 & 2
Batch size: 256
0.75 48.16% Link
Encoder & decoder layers: 9 & 3
Batch size: 256
Weight deacy: 1e-5
0.75 49.33% Link

LP denotes linear-probing. Config is mostly based on what we define in the hyperparameters section of this notebook: mae-pretraining.ipynb.

Acknowledgements

References

[1] Masked Autoencoders Are Scalable Vision Learners; He et al.; arXiv 2021; https://arxiv.org/abs/2111.06377.

You might also like...
A repository that shares tuning results of trained models generated by TensorFlow / Keras. Post-training quantization (Weight Quantization, Integer Quantization, Full Integer Quantization, Float16 Quantization), Quantization-aware training. TensorFlow Lite. OpenVINO. CoreML. TensorFlow.js. TF-TRT. MediaPipe. ONNX. [.tflite,.h5,.pb,saved_model,tfjs,tftrt,mlmodel,.xml/.bin, .onnx] Implementation of experiments in the paper Clockwork Variational Autoencoders (project website) using JAX and Flax
Implementation of experiments in the paper Clockwork Variational Autoencoders (project website) using JAX and Flax

Clockwork VAEs in JAX/Flax Implementation of experiments in the paper Clockwork Variational Autoencoders (project website) using JAX and Flax, ported

Official implementation of the paper
Official implementation of the paper "AAVAE: Augmentation-AugmentedVariational Autoencoders"

AAVAE Official implementation of the paper "AAVAE: Augmentation-AugmentedVariational Autoencoders" Abstract Recent methods for self-supervised learnin

VIMPAC: Video Pre-Training via Masked Token Prediction and Contrastive Learning

This is a release of our VIMPAC paper to illustrate the implementations. The pretrained checkpoints and scripts will be soon open-sourced in HuggingFace transformers.

EMNLP 2021 - Frustratingly Simple Pretraining Alternatives to Masked Language Modeling

Frustratingly Simple Pretraining Alternatives to Masked Language Modeling This is the official implementation for "Frustratingly Simple Pretraining Al

The official code for PRIMER: Pyramid-based Masked Sentence Pre-training for Multi-document Summarization

PRIMER The official code for PRIMER: Pyramid-based Masked Sentence Pre-training for Multi-document Summarization. PRIMER is a pre-trained model for mu

SimMIM: A Simple Framework for Masked Image Modeling
SimMIM: A Simple Framework for Masked Image Modeling

SimMIM By Zhenda Xie*, Zheng Zhang*, Yue Cao*, Yutong Lin, Jianmin Bao, Zhuliang Yao, Qi Dai and Han Hu*. This repo is the official implementation of

SeMask: Semantically Masked Transformers for Semantic Segmentation.
SeMask: Semantically Masked Transformers for Semantic Segmentation.

SeMask: Semantically Masked Transformers Jitesh Jain, Anukriti Singh, Nikita Orlov, Zilong Huang, Jiachen Li, Steven Walton, Humphrey Shi This repo co

FocusFace: Multi-task Contrastive Learning for Masked Face Recognition
FocusFace: Multi-task Contrastive Learning for Masked Face Recognition

FocusFace This is the official repository of "FocusFace: Multi-task Contrastive Learning for Masked Face Recognition" accepted at IEEE International C

Comments
  • Excellent work (`mae.ipynb`)!

    Excellent work (`mae.ipynb`)!

    @ariG23498 this is fantastic stuff. Super clean, readable, and coherent with the original implementation. A couple of suggestions that would likely make things even better:

    • Since you have already implemented masking visualization utilities how about making them part of the PatchEncoder itself? That way you could let it accept a test image, apply random masking, and plot it just like the way you are doing in the earlier cells. This way I believe the notebook will be cleaner.
    • AdamW (tfa.optimizers.adamw) is a better choice when it comes to training Transformer-based models.
    • Are we taking the loss on the correct component? I remember you mentioning it being dealt with differently.

    After these points are addressed I will take a crack at porting the training loop to TPUs along with other performance monitoring callbacks.

    opened by sayakpaul 7
  • Unshuffle the patches?

    Unshuffle the patches?

    Your code helps me a lot! However, I still have some questions. In the paper, the authors say they unshuffle the full list before applying the deocder. In the MaskedAutoencoder class of your implementation, decoder_inputs = tf.concat([encoder_outputs, masked_embeddings], axis=1)
    no unshuffling is used. I wonder if you can tell me the purpose of doing so? Thanks a lot!

    opened by changtaoli 2
  • Could you also share the weight of the pretrained decoder?

    Could you also share the weight of the pretrained decoder?

    Hi,

    Thanks for your excellent implementation! I found that you have shared the weights of the encoder, but if we want to replicate the reconstruction, the pretrained decoder is still needed. So, could you also share the weight of the pretrained decoder?

    Best Regards, Hongxin

    opened by hongxin001 1
  • Issue with the plotting utility `show_masked_image`

    Issue with the plotting utility `show_masked_image`

    Should be:

    def show_masked_image(self, patches):
            # Utility function that helps visualize maksed images.
            _, unmask_indices = self.get_random_indices()
            unmasked_patches = tf.gather(patches, unmask_indices, axis=1, batch_dims=1)
    
            # Necessary for plotting.
            ids = tf.argsort(unmask_indices)
            sorted_unmask_indices = tf.sort(unmask_indices)
            unmasked_patches = tf.gather(unmasked_patches, ids, batch_dims=1)
    
            # Select a random index for visualization.
            idx = np.random.choice(len(sorted_unmask_indices))
            print(f"Index selected: {idx}.")
    
            n = int(np.sqrt(NUM_PATCHES))
            unmask_index = sorted_unmask_indices[idx]
            unmasked_patch = unmasked_patches[idx]
    
            plt.figure(figsize=(4, 4))
    
            count = 0
            for i in range(NUM_PATCHES):
                ax = plt.subplot(n, n, i + 1)
    
                if count < unmask_index.shape[0] and unmask_index[count].numpy() == i:
                    patch = unmasked_patch[count]
                    patch_img = tf.reshape(patch, (PATCH_SIZE, PATCH_SIZE, 3))
                    plt.imshow(patch_img)
                    plt.axis("off")
                    count = count + 1
                else:
                    patch_img = tf.zeros((PATCH_SIZE, PATCH_SIZE, 3))
                    plt.imshow(patch_img)
                    plt.axis("off")
            plt.show()
    
            # Return the random index to validate the image outside the method.
            return idx
    
    opened by ariG23498 1
Releases(v1.0.0)
Owner
Aritra Roy Gosthipaty
Learning with a learning rate of 1e-10.
Aritra Roy Gosthipaty
Codes for the ICCV'21 paper "FREE: Feature Refinement for Generalized Zero-Shot Learning"

FREE This repository contains the reference code for the paper "FREE: Feature Refinement for Generalized Zero-Shot Learning". [arXiv][Paper] 1. Prepar

Shiming Chen 28 Jul 29, 2022
Image De-raining Using a Conditional Generative Adversarial Network

Image De-raining Using a Conditional Generative Adversarial Network [Paper Link] [Project Page] He Zhang, Vishwanath Sindagi, Vishal M. Patel In this

He Zhang 216 Dec 18, 2022
Interactive dimensionality reduction for large datasets

BlosSOM 🌼 BlosSOM is a graphical environment for running semi-supervised dimensionality reduction with EmbedSOM. You can use it to explore multidimen

19 Dec 14, 2022
YOLOX-CondInst - Implement CondInst which is a instances segmentation method on YOLOX

YOLOX CondInst -- YOLOX 实例分割 前言 本项目是自己学习实例分割时,复现的代码. 通过自己编程,让自己对实例分割有更进一步的了解。 若想

DDGRCF 16 Nov 18, 2022
A framework for using LSTMs to detect anomalies in multivariate time series data. Includes spacecraft anomaly data and experiments from the Mars Science Laboratory and SMAP missions.

Telemanom (v2.0) v2.0 updates: Vectorized operations via numpy Object-oriented restructure, improved organization Merge branches into single branch fo

Kyle Hundman 844 Dec 28, 2022
WTTE-RNN a framework for churn and time to event prediction

WTTE-RNN Weibull Time To Event Recurrent Neural Network A less hacky machine-learning framework for churn- and time to event prediction. Forecasting p

Egil Martinsson 727 Dec 28, 2022
Pytorch implementation for A-NeRF: Articulated Neural Radiance Fields for Learning Human Shape, Appearance, and Pose

A-NeRF: Articulated Neural Radiance Fields for Learning Human Shape, Appearance, and Pose Paper | Website | Data A-NeRF: Articulated Neural Radiance F

Shih-Yang Su 172 Dec 22, 2022
EquiBind: Geometric Deep Learning for Drug Binding Structure Prediction

EquiBind: geometric deep learning for fast predictions of the 3D structure in which a small molecule binds to a protein

Hannes Stärk 355 Jan 03, 2023
DCGAN-tensorflow - A tensorflow implementation of Deep Convolutional Generative Adversarial Networks

DCGAN in Tensorflow Tensorflow implementation of Deep Convolutional Generative Adversarial Networks which is a stabilize Generative Adversarial Networ

Taehoon Kim 7.1k Dec 29, 2022
Code for Multiple Instance Active Learning for Object Detection, CVPR 2021

MI-AOD Language: 简体中文 | English Introduction This is the code for Multiple Instance Active Learning for Object Detection (The PDF is not available tem

Tianning Yuan 269 Dec 21, 2022
Final project for Intro to CS class.

Financial Analysis Web App https://share.streamlit.io/mayurk1/fin-web-app-final-project/webApp.py 1. Project Description This project is a technical a

Mayur Khanna 1 Dec 10, 2021
Causal-BALD: Deep Bayesian Active Learning of Outcomes to Infer Treatment-Effects from Observational Data.

causal-bald | Abstract | Installation | Example | Citation | Reproducing Results DUE An implementation of the methods presented in Causal-BALD: Deep B

OATML 13 Oct 07, 2022
Rethinking Semantic Segmentation from a Sequence-to-Sequence Perspective with Transformers

Segmentation Transformer Implementation of Segmentation Transformer in PyTorch, a new model to achieve SOTA in semantic segmentation while using trans

Abhay Gupta 161 Dec 08, 2022
Deep learning with dynamic computation graphs in TensorFlow

TensorFlow Fold TensorFlow Fold is a library for creating TensorFlow models that consume structured data, where the structure of the computation graph

1.8k Dec 28, 2022
Gradient representations in ReLU networks as similarity functions

Gradient representations in ReLU networks as similarity functions by Dániel Rácz and Bálint Daróczy. This repo contains the python code related to our

1 Oct 08, 2021
Computer-Vision-Paper-Reviews - Computer Vision Paper Reviews with Key Summary along Papers & Codes

Computer-Vision-Paper-Reviews Computer Vision Paper Reviews with Key Summary along Papers & Codes. Jonathan Choi 2021 50+ Papers across Computer Visio

Jonathan Choi 2 Mar 17, 2022
PuppetGAN - Cross-Domain Feature Disentanglement and Manipulation just got way better! 🚀

Better Cross-Domain Feature Disentanglement and Manipulation with Improved PuppetGAN Quite cool... Right? Introduction This repo contains a TensorFlow

Giorgos Karantonis 5 Aug 25, 2022
This is the winning solution of the Endocv-2021 grand challange.

Endocv2021-winner [Paper] This is the winning solution of the Endocv-2021 grand challange. Dependencies pytorch # tested with 1.7 and 1.8 torchvision

Vajira Thambawita 14 Dec 03, 2022
StrongSORT: Make DeepSORT Great Again

StrongSORT StrongSORT: Make DeepSORT Great Again StrongSORT: Make DeepSORT Great Again Yunhao Du, Yang Song, Bo Yang, Yanyun Zhao arxiv 2202.13514 Abs

369 Jan 04, 2023
Sparse Physics-based and Interpretable Neural Networks

Sparse Physics-based and Interpretable Neural Networks for PDEs This repository contains the code and manuscript for research done on Sparse Physics-b

28 Jan 03, 2023