Unofficial PyTorch implementation of MobileViT, based on the paper "MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer".

Overview

MobileViT

Model Architecture

MobileViT Architecture

Usage

Training

python main.py
optional arguments:
  -h, --help            show this help message and exit
  --gpu_device GPU_DEVICE
                        Select specific GPU to run the model
  --batch-size N        Input batch size for training (default: 64)
  --epochs N            Number of epochs to train (default: 20)
  --num-class N         Number of classes to classify (default: 10)
  --lr LR               Learning rate (default: 0.01)
  --weight-decay WD     Weight decay (default: 1e-5)
  --model-path PATH     Path to save the model
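
The option list above maps naturally onto an argparse parser. The following is a minimal sketch, not the repository's actual main.py: the flag names, metavars, and documented defaults are taken from the help text, while the type and default of --gpu_device, the default of --model-path, and the function name build_parser are assumptions made for illustration.

```python
import argparse


def build_parser():
    """Build a parser mirroring the documented training options (illustrative only)."""
    parser = argparse.ArgumentParser(description="MobileViT training (sketch)")
    # Assumption: the GPU is selected by integer index; the real type/default may differ.
    parser.add_argument("--gpu_device", type=int, default=0,
                        help="Select specific GPU to run the model")
    parser.add_argument("--batch-size", type=int, default=64, metavar="N",
                        help="Input batch size for training (default: 64)")
    parser.add_argument("--epochs", type=int, default=20, metavar="N",
                        help="Number of epochs to train (default: 20)")
    parser.add_argument("--num-class", type=int, default=10, metavar="N",
                        help="Number of classes to classify (default: 10)")
    parser.add_argument("--lr", type=float, default=0.01, metavar="LR",
                        help="Learning rate (default: 0.01)")
    parser.add_argument("--weight-decay", type=float, default=1e-5, metavar="WD",
                        help="Weight decay (default: 1e-5)")
    # Assumption: no default save path is documented, so None is used here.
    parser.add_argument("--model-path", type=str, default=None, metavar="PATH",
                        help="Path to save the model")
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    print(args)
```

With a parser like this, a run that overrides some defaults would look like: python main.py --batch-size 128 --epochs 50 --lr 0.002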

Citation

@InProceedings{Sachin2021,
  title = {MOBILEVIT: LIGHT-WEIGHT, GENERAL-PURPOSE, AND MOBILE-FRIENDLY VISION TRANSFORMER},
  author = {Sachin Mehta and Mohammad Rastegari},
  booktitle = {},
  year = {2021}
}

If this implementation has any problems, please let me know. Thank you.

Comments
  • Training settings

    I really appreciate your efforts in implementing this model in PyTorch. I have one concern about the training settings: if I understand correctly, you trained the model for fewer than 5 epochs.

    In addition, the hyperparameters you adopted differ from those in the original paper. For instance, the authors train MobileViT using the AdamW optimizer, label-smoothing cross-entropy loss, and a multi-scale sampler, and the training phase includes a warmup stage.

    I also found that the classification accuracy reported here is much lower than that in the original paper.

    I conjecture that the gap in accuracy is caused by the different training settings.

    opened by hkzhang91 6
  • Loading pretrained weights failed

    import torch
    import models
    
    model = models.MobileViT_S()
    PATH = "./MobileVit-S.pth.tar"
    weights = torch.load(PATH, map_location=lambda storage, loc: storage)
    model.load_state_dict(weights['state_dict'])
    model.eval()
    torch.save(model, './model.pt')
    
    I tried to load the pretrained weights to test a demo, but the network structure does not seem to match the weights. Is there a solution?

    opened by hererookie 2
  • Model training hyperparameters

    One problem has been bothering me: the learning rate, optimizer, batch size, L2 regularization, label smoothing, and number of epochs are inconsistent with the paper. How should I modify the code?

    opened by Agino-ltp 1
  • Have you tested MobileViT on CIFAR-10?

    Thanks for your wonderful work!

    I plan to try MobileViT on a small dataset, such as MNIST, which requires adjusting the network structure. Before doing so, I would like to know whether MobileViT performs better than other networks on small datasets.

    I noticed "get_cifar10_dataset" in utils.py. Have you tested MobileViT on CIFAR-10? If so, could you please share the accuracy and inference time?

    opened by Jerryme-xxm 1
  • Issues when loading MobileViT_S()

    I wanted to load the MobileViT_S() model with the pretrained weights, but I ran into some errors. To help others (in case someone else is a beginner like me), here is my solution:

    import torch
    from models import MobileViT_S  # assumes the repository's models module is importable

    def load_mobilevit_weights(model_path):
      # Create an instance of the MobileViT model
      net = MobileViT_S()

      # Load the checkpoint on the CPU and extract its state_dict
      state_dict = torch.load(model_path, map_location=torch.device('cpu'))['state_dict']

      # The checkpoint keys carry a 'module.' prefix (typically added by nn.DataParallel),
      # so strip only that leading prefix to match the MobileViT layer names
      prefix = 'module.'
      state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
      }

      # Once the keys are fixed, load the parameters into the model
      net.load_state_dict(state_dict)

      return net

    net = load_mobilevit_weights("MobileViT_S_model_best.pth.tar")
    
    opened by Sehaba95 4
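
Several of the threads above note that the paper trains with label-smoothing cross-entropy and a warmup stage, while this repository uses different settings. As a rough, framework-free sketch of those two ingredients, the block below computes a learning rate with linear warmup followed by cosine decay, and a label-smoothed cross-entropy for a single example. The function names, the choice of cosine decay after warmup, and every numeric value are assumptions for illustration; they are not taken from this repository and should be checked against the paper before use.

```python
import math


def warmup_cosine_lr(step, total_steps, warmup_steps, base_lr, min_lr):
    """Learning rate at `step`: linear warmup from min_lr to base_lr,
    then cosine decay back down to min_lr (an assumed schedule shape)."""
    if warmup_steps <= 0 or total_steps <= warmup_steps:
        raise ValueError("need 0 < warmup_steps < total_steps")
    if step < warmup_steps:
        # Linear ramp during the warmup stage.
        return min_lr + (base_lr - min_lr) * step / warmup_steps
    # Fraction of the post-warmup phase completed, clamped to [0, 1].
    progress = min(1.0, (step - warmup_steps) / (total_steps - warmup_steps))
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


def label_smoothing_cross_entropy(logits, target, smoothing=0.1):
    """Cross-entropy for one example where the one-hot target is mixed
    with a uniform distribution over all classes."""
    # Numerically stable log-softmax.
    max_logit = max(logits)
    log_z = max_logit + math.log(sum(math.exp(x - max_logit) for x in logits))
    log_probs = [x - log_z for x in logits]
    nll = -log_probs[target]                    # loss against the true class
    uniform = -sum(log_probs) / len(log_probs)  # loss against the uniform target
    return (1.0 - smoothing) * nll + smoothing * uniform


if __name__ == "__main__":
    # Placeholder values chosen only to exercise the functions.
    for step in (0, 50, 100, 550, 1000):
        print(step, warmup_cosine_lr(step, 1000, 100, 0.002, 0.0002))
    print(label_smoothing_cross_entropy([2.0, 0.5, -1.0], target=0))
```

In a PyTorch training loop, a schedule like this could be supplied through torch.optim.lr_scheduler.LambdaLR, which expects a multiplicative factor relative to the optimizer's initial learning rate, so the returned value would first be divided by base_lr. Recent PyTorch versions also expose label smoothing directly through the label_smoothing argument of torch.nn.CrossEntropyLoss, and AdamW is available as torch.optim.AdamW.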
Releases(weight)
Owner
Hong-Jia Chen
Master student at National Chung Cheng University, Taiwan. Interested in Deep Learning and Computer Vision.