This is the official PyTorch implementation for "Mesa: A Memory-saving Training Framework for Transformers".

Related tags

Deep LearningMesa
Overview

Mesa: A Memory-saving Training Framework for Transformers

This is the official PyTorch implementation for Mesa: A Memory-saving Training Framework for Transformers.

By Zizheng Pan, Peng Chen, Haoyu He, Jing Liu, Jianfei Cai and Bohan Zhuang.

image-20211116105242785

Installation

  1. Create a virtual environment with anaconda.

    conda create -n mesa python=3.7 -y
    conda activate mesa
    
    # Install PyTorch, we use PyTorch 1.7.1 with CUDA 10.1 
    pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
    
    # Install ninja
    pip install ninja
  2. Build and install Mesa.

    # cloen this repo
    git clone https://github.com/zhuang-group/Mesa
    # build
    cd Mesa/
    # You need to have an NVIDIA GPU
    python setup.py develop

Usage

  1. Prepare your policy and save as a text file, e.g. policy.txt.

    on gelu: # layer tag, choices: fc, conv, gelu, bn, relu, softmax, matmul, layernorm
        by_index: all # layer index
        enable: True # enable for compressing
        level: 256 # we adopt 8-bit quantization by default
        ema_decay: 0.9 # the decay rate for running estimates
        
        by_index: 1 2 # e.g. exluding GELU layers that indexed by 1 and 2.
        enable: False
  2. Next, you can wrap your model with Mesa by:

    import mesa as ms
    ms.policy.convert_by_num_groups(model, 3)
    # or convert by group size with ms.policy.convert_by_group_size(model, 64)
    
    # setup compression policy
    ms.policy.deploy_on_init(model, '[path to policy.txt]', verbose=print, override_verbose=False)

    That's all you need to use Mesa for memory saving.

    Note that convert_by_num_groups and convert_by_group_size only recognize nn.XXX, if your code has functional operations, such as [email protected] and F.Softmax, you may need to manually setup these layers. For example:

    # matrix multipcation (before)
    out = Q@K.transpose(-2, -1)
    # with Mesa
    self.mm = ms.MatMul(quant_groups=3)
    out = self.mm(q, k.transpose(-2, -1))
    
    # sofmtax (before)
    attn = attn.softmax(dim=-1)
    # with Mesa
    self.softmax = ms.Softmax(dim=-1, quant_groups=3)
    attn = self.softmax(attn)
  3. You can also target one layer by:

    import mesa as ms
    # previous 
    self.act = nn.GELU()
    # with Mesa
    self.act = ms.GELU(quant_groups=[num of quantization groups])

Demo projects for DeiT and Swin

We provide demo projects to replicate our results of training DeiT and Swin with Mesa, please refer to DeiT-Mesa and Swin-Mesa.

Results on ImageNet

Model Param (M) FLOPs (G) Train Memory Top-1 (%)
DeiT-Ti 5 1.3 4,171 71.9
DeiT-Ti w/ Mesa 5 1.3 1,858 72.1
DeiT-S 22 4.6 8,459 79.8
DeiT-S w/ Mesa 22 4.6 3,840 80.0
DeiT-B 86 17.5 17,691 81.8
DeiT-B w/ Mesa 86 17.5 8,616 81.8
Swin-Ti 29 4.5 11,812 81.3
Swin-Ti w/ Mesa 29 4.5 5,371 81.3
PVT-Ti 13 1.9 7,800 75.1
PVT-Ti w/ Mesa 13 1.9 3,782 74.9

License

This repository is released under the Apache 2.0 license as found in the LICENSE file.

Acknowledgments

This repository has adopted part of the quantization codes from ActNN, we thank the authors for their open-sourced code.

Owner
Zhuang AI Group
Zhuang AI Group
Self-Supervised Contrastive Learning of Music Spectrograms

Self-Supervised Music Analysis Self-Supervised Contrastive Learning of Music Spectrograms Dataset Songs on the Billboard Year End Hot 100 were collect

27 Dec 10, 2022
Transformer - Transformer in PyTorch

Transformer 完成进度 Embeddings and PositionalEncoding with example. MultiHeadAttent

Tianyang Li 1 Jan 06, 2022
The MLOps platform for innovators 🚀

​ DS2.ai is an integrated AI operation solution that supports all stages from custom AI development to deployment. It is an AI-specialized platform service that collects data, builds a training datas

9 Jan 03, 2023
Training Certifiably Robust Neural Networks with Efficient Local Lipschitz Bounds (Local-Lip)

Training Certifiably Robust Neural Networks with Efficient Local Lipschitz Bounds (Local-Lip) Introduction TL;DR: We propose an efficient and trainabl

17 Dec 01, 2022
Permute Me Softly: Learning Soft Permutations for Graph Representations

Permute Me Softly: Learning Soft Permutations for Graph Representations

Giannis Nikolentzos 7 Jul 10, 2022
some classic model used to segment the medical images like CT、X-ray and so on

github_project This is a project for medical image segmentation. This project includes common medical image segmentation models such as U-net, FCN, De

2 Mar 30, 2022
This GitHub repository contains code used for plots in NeurIPS 2021 paper 'Stochastic Multi-Armed Bandits with Control Variates.'

About Repository This repository contains code used for plots in NeurIPS 2021 paper 'Stochastic Multi-Armed Bandits with Control Variates.' About Code

Arun Verma 1 Nov 09, 2021
Differentiable Abundance Matching With Python

shamnet Differentiable Stellar Population Synthesis Installation You can install shamnet with pip. Installation dependencies are numpy, jax, corrfunc,

5 Dec 17, 2021
Code for NeurIPS 2021 paper "Curriculum Offline Imitation Learning"

README The code is based on the ILswiss. To run the code, use python run_experiment.py --nosrun -e your YAML file -g gpu id Generally, run_experim

ApexRL 12 Mar 19, 2022
PyTorch implementation of SampleRNN: An Unconditional End-to-End Neural Audio Generation Model

samplernn-pytorch A PyTorch implementation of SampleRNN: An Unconditional End-to-End Neural Audio Generation Model. It's based on the reference implem

DeepSound 261 Dec 14, 2022
A collection of implementations of deep domain adaptation algorithms

Deep Transfer Learning on PyTorch This is a PyTorch library for deep transfer learning. We divide the code into two aspects: Single-source Unsupervise

Yongchun Zhu 647 Jan 03, 2023
This repo provides the base code for pytorch-lightning and weight and biases simultaneous integration.

Write your model faster with pytorch-lightning-wadb-code-backbone This repository provides the base code for pytorch-lightning and weight and biases s

9 Mar 29, 2022
Deeper insights into graph convolutional networks for semi-supervised learning

deeper_insights_into_GCNs Deeper insights into graph convolutional networks for semi-supervised learning References data and utils.py come from Implem

Davidham3 17 Dec 16, 2022
Docker containers of baseline agents for the Crafter environment

Crafter Baselines This repository contains Docker containers for running various baselines on the Crafter environment. Reward Agents DreamerV2 based o

Danijar Hafner 17 Sep 25, 2022
[ACM MM 2021] TSA-Net: Tube Self-Attention Network for Action Quality Assessment

Tube Self-Attention Network (TSA-Net) This repository contains the PyTorch implementation for paper TSA-Net: Tube Self-Attention Network for Action Qu

ShunliWang 18 Dec 23, 2022
Source code to accompany Defunctland's video "FASTPASS: A Complicated Legacy"

Shapeland Simulator Source code to accompany Defunctland's video "FASTPASS: A Complicated Legacy" Download the video at https://www.youtube.com/watch?

TouringPlans.com 70 Dec 14, 2022
Stacked Recurrent Hourglass Network for Stereo Matching

SRH-Net: Stacked Recurrent Hourglass Introduction This repository is supplementary material of our RA-L submission, which helps reviewers to understan

28 Jan 03, 2023
PyTorch Implementation of ByteDance's Cross-speaker Emotion Transfer Based on Speaker Condition Layer Normalization and Semi-Supervised Training in Text-To-Speech

Cross-Speaker-Emotion-Transfer - PyTorch Implementation PyTorch Implementation of ByteDance's Cross-speaker Emotion Transfer Based on Speaker Conditio

Keon Lee 114 Jan 08, 2023
XViT - Space-time Mixing Attention for Video Transformer

XViT - Space-time Mixing Attention for Video Transformer This is the official implementation of the XViT paper: @inproceedings{bulat2021space, title

Adrian Bulat 33 Dec 23, 2022
Pytorch implementation for the EMNLP 2020 (Findings) paper: Connecting the Dots: A Knowledgeable Path Generator for Commonsense Question Answering

Path-Generator-QA This is a Pytorch implementation for the EMNLP 2020 (Findings) paper: Connecting the Dots: A Knowledgeable Path Generator for Common

Peifeng Wang 33 Dec 05, 2022