A PyTorch implementation of Masked Autoencoders Are Scalable Vision Learners

Overview

This is a preliminary version of MAE. Only the pretraining model is implemented so far; finetuning and linear probing are coming soon.

1. Introduction

This repo implements the MAE-ViT model in PyTorch. It was written without consulting any reference code, so it is an unofficial version. Because of limited time and machines, I only trained the ViT-tiny model as the encoder.

2. Environment

  • python 3.7+
  • pytorch 1.7.1
  • pillow
  • timm
  • opencv-python

3. Model Config

Pretrain Config

  • BaseConfig
    img_size = 224,
    patch_size = 16,
  • Encoder: the encoder follows the ViT-tiny model config.
    encoder_dim = 192,
    encoder_depth = 12,
    encoder_heads = 3,
  • Decoder: the decoder follows the config from Kaiming He's paper.
    decoder_dim = 512,
    decoder_depth = 8,
    decoder_heads = 16, 
  • Mask
    1. Shuffle the patches after adding the sin-cos position embedding for the encoder.
    2. Mask the shuffled patches and keep the mask indices.
    3. Unshuffle the masked patches and combine them with the encoder embeddings before adding the position embedding for the decoder.
    4. Reconstruct the image from the decoder embeddings with a ConvTranspose layer.
    5. Build the mask map from the mask indices to calculate the loss (only the masked patches are considered).
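
The mask steps above can be sketched in PyTorch as follows. This is a minimal illustration, not the code in this repo: the function names (`random_masking`, `unshuffle`, `masked_mse`) and the 75% mask ratio (the value used in the MAE paper) are assumptions, and the encoder, the decoder, the projection between their widths, and the ConvTranspose reconstruction head are left out.

```python
import torch


def random_masking(tokens, mask_ratio=0.75):
    """Steps 1-2: shuffle the patch tokens and keep only the first part.

    tokens: (B, N, D) patch embeddings with the position embedding already added.
    Returns the visible tokens, a mask in the original patch order
    (1 = masked, 0 = visible), and the indices that undo the shuffle.
    mask_ratio=0.75 is an assumption taken from the MAE paper.
    """
    B, N, D = tokens.shape
    num_keep = int(N * (1 - mask_ratio))

    noise = torch.rand(B, N, device=tokens.device)
    ids_shuffle = torch.argsort(noise, dim=1)        # random permutation per image
    ids_restore = torch.argsort(ids_shuffle, dim=1)  # inverse permutation

    ids_keep = ids_shuffle[:, :num_keep]
    visible = torch.gather(tokens, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))

    mask = torch.ones(B, N, device=tokens.device)
    mask[:, :num_keep] = 0
    mask = torch.gather(mask, 1, ids_restore)        # back to original patch order
    return visible, mask, ids_restore


def unshuffle(encoded, mask_token, ids_restore):
    """Step 3: append mask tokens and restore the original patch order.

    encoded: (B, num_keep, D) encoder output; mask_token: (1, 1, D).
    """
    B, N = ids_restore.shape
    D = encoded.shape[-1]
    num_masked = N - encoded.shape[1]
    full = torch.cat([encoded, mask_token.expand(B, num_masked, D)], dim=1)
    return torch.gather(full, 1, ids_restore.unsqueeze(-1).expand(-1, -1, D))


def masked_mse(pred, target, mask):
    """Step 5: mean squared error averaged over masked patches only.

    pred, target: (B, N, P) per-patch pixels; mask: (B, N) with 1 = masked.
    """
    per_patch = ((pred - target) ** 2).mean(dim=-1)
    return (per_patch * mask).sum() / mask.sum()


# Example: 224 / 16 = 14 patches per side, so 196 patches per image.
tokens = torch.randn(2, (224 // 16) ** 2, 192)
visible, mask, ids_restore = random_masking(tokens)                   # visible: (2, 49, 192)
decoder_in = unshuffle(visible, torch.zeros(1, 1, 192), ids_restore)  # (2, 196, 192)
```

Sorting random noise gives an independent permutation per image, and a second `argsort` gives its inverse, so the kept indices never need to be stored separately to restore the original order.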

Finetune Config

Waiting for the results.

TODO:

  • Finetune training
  • Linear training

4. Results

The decoder reconstructs ImageNet validation images from the pretrained model. Compared with the results in Kaiming He's paper, the reconstruction quality is lower, possibly because the encoder model is too small.

The MAE-ViT-tiny pretrained model is here; you can download it to test the reconstruction results. Put the checkpoint in the weights folder.

5. Training & Inference

  • Dataset preparation: list one image path and its class label per line

    /data/home/imagenet/xxx.jpeg, 0
    /data/home/imagenet/xxx.jpeg, 1
    ...
    /data/home/imagenet/xxx.jpeg, 999
    
  • Training

    1. Pretrain

      #!/bin/bash
      # Use one OpenMP/MKL thread per process so the 8 workers do not oversubscribe the CPU.
      export OMP_NUM_THREADS=1
      export MKL_NUM_THREADS=1
      cd MAE-Pytorch
      CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -W ignore -m torch.distributed.launch --nproc_per_node 8 train_mae.py \
      --batch_size 256 \
      --num_workers 32 \
      --lr 1.5e-4 \
      --optimizer_name "adamw" \
      --cosine 1 \
      --max_epochs 300 \
      --warmup_epochs 40 \
      --num-classes 1000 \
      --crop_size 224 \
      --patch_size 16 \
      --color_prob 0.0 \
      --calculate_val 0 \
      --weight_decay 5e-2 \
      --lars 0 \
      --mixup 0.0 \
      --smoothing 0.0 \
      --train_file $train_file \
      --val_file $val_file \
      --checkpoints-path $ckpt_folder \
      --log-dir $log_folder
    2. Finetune TODO:

      • training
    3. Linear TODO:

      • training
  • Inference

    1. Pretrain
      python mae_test.py --test_image xxx.jpg --ckpt weights.pth
    2. Classification TODO:
      • training
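
The list file shown under dataset preparation can be read with a few lines of Python. This is a minimal sketch that assumes each line has the form `<image path>, <integer label>`; the helper name `read_image_list` is hypothetical, and how `train_mae.py` actually parses `$train_file` is not shown in this README.

```python
def read_image_list(lines):
    """Parse lines of the form '<image path>, <label>' into (path, label) tuples.

    Splitting on the last comma keeps paths that themselves contain commas intact.
    Blank lines are skipped.
    """
    samples = []
    for raw in lines:
        line = raw.strip()
        if not line:
            continue
        path, label = line.rsplit(",", 1)
        samples.append((path.strip(), int(label.strip())))
    return samples
```

For example, `with open(train_file) as f: samples = read_image_list(f)` would return a list of `(path, label)` pairs that a dataset class can index.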

6. TODO

  • ViT-Base model training.
  • Swin Transformer for MAE.
  • Finetune & linear training.

Finetuning is in progress; the weights may be coming soon.

Owner
FlyEgle
JOYY AI GROUP - Machine Learning Engineer(Computer Vision)