
Masked Auto-Encoder (MAE)

PyTorch implementation of Masked Auto-Encoder (He et al., "Masked Autoencoders Are Scalable Vision Learners").
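
For context, MAE pre-training masks a large random subset of patch tokens and encodes only the visible ones. Below is a minimal, hypothetical sketch of that masking step (the repository's actual implementation lives in model.py and may differ in names and details):

```python
import torch

def random_masking(tokens, mask_ratio=0.75):
    # tokens: (batch, num_patches, dim) patch embeddings
    B, N, D = tokens.shape
    n_keep = int(N * (1 - mask_ratio))
    noise = torch.rand(B, N, device=tokens.device)  # per-sample random scores
    ids_shuffle = noise.argsort(dim=1)              # random permutation per sample
    ids_keep = ids_shuffle[:, :n_keep]              # indices of the visible patches
    visible = torch.gather(
        tokens, 1, ids_keep.unsqueeze(-1).expand(B, n_keep, D))
    return visible, ids_shuffle                     # encoder sees only `visible`
```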

Usage

  1. Clone the repository.
> git clone https://github.com/liujiyuan13/MAE-code.git MAE-code
  2. Install required packages.
> cd MAE-code
> pip install -r requirements.txt
  3. Prepare the datasets.
  • For CIFAR-10, CIFAR-100 and STL-10, skip this step; these datasets are downloaded automatically.
  • For ImageNet1K, download and unzip the train(val) set into ./data/ImageNet1K/train(val).
  4. Set the parameters.
  • All parameters are kept in the default_args() function of main_mae(eval).py.
  5. Run the code.
> python main_mae.py	# train MAE encoder
> python main_eval.py	# evaluate MAE encoder
  6. Visualize the output.
> tensorboard --logdir=./log --port 8888

Details

Project structure

...
+ ckpt				# checkpoint
+ data 				# data folder
+ img 				# store images for README.md
+ log 				# log files
.gitignore 			
lars.py 			# LARS optimizer
main_eval.py 			# main file for evaluation
main_mae.py  			# main file for MAE training
model.py 			# model definitions of MAE and EvalNet
README.md 
util.py 			# helper functions
vit.py 				# definition of vision transformer

Encoder setting

In the paper, ViT-Base, ViT-Large and ViT-Huge are used as the encoder. You can switch between them by simply changing the parameters in default_args(); the configurations follow the ViT paper and are listed in the following table.

Name    Layer Num.  Hidden Size  MLP Size     Head Num.
Arg     vit_depth   vit_dim      vit_mlp_dim  vit_heads
ViT-B   12          768          3072         12
ViT-L   24          1024         4096         16
ViT-H   32          1280         5120         16
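
For example, selecting ViT-Base corresponds to settings like the following (an illustrative snippet using the argument names from the table; in this repository the values live in default_args()):

```python
from argparse import Namespace

# ViT-Base configuration, matching the table above (illustrative only)
args = Namespace(
    vit_depth=12,      # Layer Num.
    vit_dim=768,       # Hidden Size
    vit_mlp_dim=3072,  # MLP Size
    vit_heads=12,      # Head Num.
)
```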

Evaluation setting

I implement the four network training strategies considered in the paper:

  • pre-training trains the MAE encoder and is done in main_mae.py.
  • linear probing evaluates the MAE encoder. During training, the encoder is fully fixed.
    • args.n_partial = 0
  • partial fine-tuning evaluates the MAE encoder. During training, the encoder is partially fixed.
    • args.n_partial = 0.5 --> fine-tune the MLP sub-block with the transformer fixed
    • 1 <= args.n_partial <= args.vit_depth-1 --> fine-tune the MLP sub-block and the last layers of the transformer
  • end-to-end fine-tuning evaluates the MAE encoder. During training, the encoder is fully trainable.
    • args.n_partial = args.vit_depth

Note that the last three strategies are done in main_eval.py, where the parameter args.n_partial is located; a sketch of its semantics follows.
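
A minimal sketch of how args.n_partial could map onto frozen parameters, assuming the encoder exposes its transformer layers as encoder.blocks and each block has an mlp sub-module (these names are hypothetical; the actual logic is in main_eval.py):

```python
def set_trainable(encoder, n_partial, vit_depth):
    # start fully frozen: n_partial = 0 corresponds to linear probing
    for p in encoder.parameters():
        p.requires_grad = False
    if n_partial == 0.5:
        # tune only the MLP sub-block of the last transformer layer
        for p in encoder.blocks[-1].mlp.parameters():
            p.requires_grad = True
    elif n_partial >= 1:
        # tune the last n_partial transformer layers;
        # n_partial = vit_depth gives end-to-end fine-tuning
        n = int(min(n_partial, vit_depth))
        for blk in encoder.blocks[-n:]:
            for p in blk.parameters():
                p.requires_grad = True
```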

At the same time, I follow the parameter settings in the paper's appendix. Note that partial fine-tuning and end-to-end fine-tuning use the same setting. Nevertheless, I replace RandAug(9, 0.5) with RandomResizedCrop and leave mixup, cutmix and drop path for future implementation.
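
With that substitution, the pre-training augmentation reduces to something like the torchvision pipeline below (a sketch: the crop scale and the ImageNet normalization statistics are common defaults, not values taken from this repository):

```python
from torchvision import transforms

# illustrative pre-training augmentation: RandomResizedCrop replaces RandAug(9, 0.5)
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),  # assumed crop scale
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],      # assumed ImageNet stats
                         std=[0.229, 0.224, 0.225]),
])
```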

Result

Reproducing the experiments takes a long time, and I am unfortunately busy these days. If you obtain some results and are willing to contribute, please reach me via email. Thanks!

By the way, I have run the code from start to end and it works, so don't worry about implementation errors. If you find any, please raise an issue or email me.

Licence

This repository is released under GPL v3.

About

Thanks to the projects vit-pytorch, pytorch-lars and DeepLearningExamples, whose code contributed a lot to this repository!

Homepage: https://liujiyuan13.github.io

Email: [email protected]
