PyTorch-LIT is the Lite Inference Toolkit (LIT) for PyTorch, which focuses on easy and fast inference of large models on end devices.

PyTorch-LIT

With the rapid growth of deep learning research, models are becoming increasingly large and complex, which makes them difficult to run on currently available end devices. For example, GPT-J, with 6B parameters, needs 24 GB of RAM in full-precision mode just to be ready for execution, which may be impossible on most systems. Even a powerful GPU like the RTX 2060, with 6 GB of memory, cannot hold GPT-J in half-precision mode (about 12 GB), so direct inference is impossible.
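As a quick sanity check on these figures: weight memory is roughly the parameter count times the bytes per element (4 for float32, 2 for float16). This ignores activations and framework overhead. The helper below is a hypothetical illustration and is not part of PyTorch-LIT:

```python
# Bytes per element for the two precisions mentioned above.
BYTES_PER_PARAM = {"float32": 4, "float16": 2}


def weight_memory_gb(n_params: int, dtype: str) -> float:
    """Memory needed just to store the weights, in GB (1 GB = 10**9 bytes)."""
    return n_params * BYTES_PER_PARAM[dtype] / 10**9


GPTJ_PARAMS = 6_000_000_000  # "6B" parameters, as stated above

full = weight_memory_gb(GPTJ_PARAMS, "float32")
half = weight_memory_gb(GPTJ_PARAMS, "float16")
print(f"float32: {full:.0f} GB, float16: {half:.0f} GB")  # float32: 24 GB, float16: 12 GB
print("fits in a 6 GB GPU at half precision:", half <= 6)  # fits in a 6 GB GPU at half precision: False
```

Halving the precision halves the footprint, but 12 GB is still twice the RTX 2060's memory.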

For training large models, libraries such as DeepSpeed address this issue with offload techniques (e.g., ZeRO). These manage the parameters and make training possible by dividing the weights between devices. For inference, by contrast, no comparable library or framework is directly available.

PyTorch-LIT enables inference of large models by loading weights on demand from a specified secondary memory (disk, CPU, or GPU). Models that do not even fit in the system's main memory can therefore run, at the cost of longer inference time.

Quick Start

  1. Install the library
pip install pytorch-lit
  2. Save the model's weights in a format the toolkit can use
from pytorch_lit.export import prepare_params

weights = {}  # your model's parameters (state_dict)
# choose the directory to save the prepared weights in, and the data type
prepare_params(weights, ".models/my-model", dtype="float32")
  3. After preparing the weights, you can run inference with your model
from pytorch_lit import LitModule

# pass your model construction as a closure,
# then specify the weights path and the inference device;
# MyModel, args and kwargs are placeholders for your own model class and inputs
model = LitModule.from_params(".models/my-model",
                              lambda: MyModel(),
                              device="cuda")
result = model(*args, **kwargs)
  4. Enjoy running inference of large models on lower-memory devices :)

Examples

The repository's examples directory currently contains two GPT-J examples: one for text generation and one for extracting hidden states as feature representations.

Development

This is a work in progress that will require further development before it can be considered a stable inference toolkit. Here is a list of potential future developments:

  • Caching and batch-loading as many weights as memory allows, replacing them in parallel with upcoming ones (following the order of the execution graph)
  • A C++ extension for PyTorch JIT, so the solution applies to most production end devices
  • Functions that make it easier to export large models to ONNX or trace them with JIT
  • A better and faster storage format than numpy memmap
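The first item in this list could look roughly like the sketch below. While one layer runs, a background worker loads the weights of the next few layers in execution order. This is only an illustration of the idea under stated assumptions. `run_with_prefetch`, `load` and `run` are hypothetical names, and the fixed `lookahead` window stands in for "as many weights as memory allows":

```python
from concurrent.futures import ThreadPoolExecutor


def run_with_prefetch(order, load, run, lookahead=2):
    """Run layers in execution `order`, loading upcoming layers' weights in the background.

    load(name) -> weights and run(name, weights) -> output are hypothetical callbacks.
    """
    futures = {}  # index in `order` -> Future holding that layer's weights
    outputs = []
    # a single worker loads weights in submission order, i.e. in execution order
    with ThreadPoolExecutor(max_workers=1) as pool:
        for i, name in enumerate(order):
            # keep the current layer and the next `lookahead` layers scheduled
            for j in range(i, min(i + lookahead + 1, len(order))):
                if j not in futures:
                    futures[j] = pool.submit(load, order[j])
            weights = futures.pop(i).result()  # blocks only if this load is not finished yet
            outputs.append(run(name, weights))
            # dropping the reference to `weights` here is the "unload" step
    return outputs


order = ["embed", "block0", "block1", "head"]
outputs = run_with_prefetch(order, load=lambda n: f"w[{n}]", run=lambda n, w: f"{n}<-{w}")
print(outputs)  # ['embed<-w[embed]', 'block0<-w[block0]', 'block1<-w[block1]', 'head<-w[head]']
```

A larger `lookahead` keeps more weights resident but makes it less likely that a layer has to wait for its own load.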

Contributions are welcome. To discuss your idea further, open an issue with the discussion tag; you can then submit a pull request from your fork.

How does it work?

This implementation was made possible primarily by two ideas:

  • The first issue is that PyTorch initializes a model object's parameters when constructing it, so construction fails when the model cannot fit into memory. To address this, we temporarily hijack the __new__ method of PyTorch's Parameter class during model construction. This lets us replace each parameter's tensor with a view into a shared global tensor immediately after it is created. As a result, all parameters use the same big shared tensor as their primary storage. The model can then be built and run on test inputs to follow and trace the execution graph.
  • The second issue is the large size of the model parameters. In the preparation step, we build a numpy memmap (np.memmap) and save metadata recording the location of each key in the memmap, so parameters can be read from the memmap as needed. We then use PyTorch hooks (forward and pre_forward) to load a module's parameters before its execution and unload them afterwards.
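The first idea is to temporarily hijack a constructor so that every new parameter is backed by one shared buffer. It can be illustrated without PyTorch. The sketch below is a toy: a minimal `Parameter` class, `bytearray` and `memoryview` stand in for `torch.nn.Parameter` and tensors. The names `Parameter`, `TinyModel` and `shared_parameter_storage` are hypothetical, and this is not PyTorch-LIT's actual code:

```python
from contextlib import contextmanager


class Parameter:
    """Toy stand-in for torch.nn.Parameter: wraps an already-allocated buffer."""

    def __new__(cls, data):
        obj = super().__new__(cls)
        obj.data = data
        return obj


class TinyModel:
    """Toy model: each layer allocates a fresh buffer, then wraps it in a Parameter."""

    def __init__(self):
        self.w1 = Parameter(bytearray(4 * 6))  # room for 6 float32 values
        self.b1 = Parameter(bytearray(4 * 3))
        self.w2 = Parameter(bytearray(4 * 3))


@contextmanager
def shared_parameter_storage(shared: bytearray):
    """Temporarily patch Parameter.__new__ so every new parameter views `shared`."""
    saved = Parameter.__dict__["__new__"]  # the staticmethod object, restored later
    original_new = Parameter.__new__  # the underlying function

    def hijacked_new(cls, data):
        if len(data) > len(shared):
            raise ValueError("shared buffer is smaller than the parameter")
        # discard the fresh allocation and keep only a view into the shared buffer
        return original_new(cls, memoryview(shared)[: len(data)])

    Parameter.__new__ = staticmethod(hijacked_new)
    try:
        yield shared
    finally:
        Parameter.__new__ = saved  # always undo the patch


shared = bytearray(4 * 8)  # sized for the largest parameter
with shared_parameter_storage(shared):
    model = TinyModel()
shared[0] = 7
print(model.w1.data[0], model.w2.data[0])  # 7 7
print(type(Parameter(bytearray(4)).data).__name__)  # bytearray
```

Every parameter aliases the start of the same buffer, so their values are meaningless at this point. That is acceptable because the real weights are loaded later. The goal is only to construct the model without keeping all of its parameters allocated at once.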
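The second idea can also be sketched with the standard library. Parameters are written back to back into one flat file, and a metadata table records each key's byte offset and length. A module then loads its weights just before it runs and drops them right after. Here `mmap` and `array` stand in for `np.memmap`, and a toy pre/post hook list stands in for PyTorch's hooks. `export_params`, `ParamStore` and `LazyLayer` are hypothetical names; this illustrates the approach, not PyTorch-LIT's implementation or file format:

```python
import json
import mmap
import os
import tempfile
from array import array


def export_params(params, base_path):
    """Write float32 parameters back to back; return {key: (byte_offset, count)}."""
    meta, offset = {}, 0
    with open(base_path + ".bin", "wb") as f:
        for key, values in params.items():
            buf = array("f", values)  # "f" = 32-bit float
            f.write(buf.tobytes())
            meta[key] = (offset, len(buf))
            offset += len(buf) * buf.itemsize
    with open(base_path + ".json", "w") as f:
        json.dump(meta, f)
    return meta


class ParamStore:
    """Memory-maps the weight file and reads one parameter at a time on demand."""

    def __init__(self, base_path):
        with open(base_path + ".json") as f:
            self.meta = json.load(f)
        self._file = open(base_path + ".bin", "rb")
        self._map = mmap.mmap(self._file.fileno(), 0, access=mmap.ACCESS_READ)

    def load(self, key):
        offset, count = self.meta[key]
        values = array("f")
        values.frombytes(self._map[offset : offset + count * values.itemsize])
        return values

    def close(self):
        self._map.close()
        self._file.close()


class LazyLayer:
    """Toy layer whose weights exist only while it runs, via pre/post hooks."""

    def __init__(self, name, store):
        self.name, self.store, self.weight = name, store, None
        self.pre_hooks = [lambda m: setattr(m, "weight", m.store.load(m.name))]
        self.post_hooks = [lambda m: setattr(m, "weight", None)]

    def forward(self, x):
        return x * sum(self.weight)  # stand-in computation

    def __call__(self, x):
        for hook in self.pre_hooks:  # load weights
            hook(self)
        out = self.forward(x)
        for hook in self.post_hooks:  # unload weights
            hook(self)
        return out


with tempfile.TemporaryDirectory() as tmp:
    base = os.path.join(tmp, "model")
    export_params({"layer0": [0.5, 0.5], "layer1": [2.0, 1.0]}, base)
    store = ParamStore(base)
    x = 1.0
    for layer in (LazyLayer("layer0", store), LazyLayer("layer1", store)):
        x = layer(x)
    store.close()
print(x)  # 3.0
```

Only one layer's weights are materialized at a time. The operating system pages in just the byte ranges that `load` actually touches.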

Citation

Please cite PyTorch-LIT if it helps your research. You can use the following BibTeX entry:

@misc{pytorch_lit,
	title = {PyTorch-LIT},
	author = {Rezaei, Amin},
	howpublished = {\url{github.com/AminRezaei0x443/PyTorch-LIT}},
	year = {2021}
}

Comments
  • RuntimeError: OrderedDict mutated during iteration

    Hi, there are new problems. When the model runs forward, it raises RuntimeError: OrderedDict mutated during iteration. Details below:

    Traceback (most recent call last):
      File "nlp/rct-FPM-rhino/big_model/predict.py", line 24, in <module>
        result = model(**tokens)
      File "miniconda3/envs/rhino/lib/python3.8/site-packages/pytorch_lit/inference.py", line 34, in __call__
        return self.forward(*args, **kwargs)
      File "miniconda3/envs/rhino/lib/python3.8/site-packages/pytorch_lit/inference.py", line 31, in forward
        return self.module(*args, **kwargs)
      File "miniconda3/envs/rhino/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1057, in _call_impl
        for hook in itertools.chain(
    RuntimeError: OrderedDict mutated during iteration

    Environment:

    GPU: NVIDIA GeForce 3090, CUDA version 11.4
    pip list: certifi 2021.10.8, charset-normalizer 2.0.8, click 8.0.3, filelock 3.4.0, huggingface-hub 0.2.0, idna 3.3, joblib 1.1.0, numpy 1.21.4, packaging 21.3, Pillow 8.4.0, pip 21.2.4, pyparsing 3.0.6, pytorch-lit 0.1.7, PyYAML 6.0, regex 2021.11.10, requests 2.26.0, sacremoses 0.0.46, setuptools 58.0.4, six 1.16.0, tokenizer 3.3.2, tokenizers 0.10.3, torch 1.9.1+cu111, torchaudio 0.8.1, torchvision 0.9.1+cu111, tqdm 4.62.3, transformers 4.12.5, typing_extensions 4.0.1, urllib3 1.26.7

    I think this problem is caused by the PyTorch hooks (forward and pre_forward) that load and unload a module's parameters before and after execution: the OrderedDict is mutated when the parameters are loaded and unloaded.

    opened by changleilei 9
  • TypeError: <lambda>() missing 1 required positional argument: 'k'

    Hello, when I use pytorch-lit to prepare a model, I get the TypeError in the title. Details below:

      File "nlp/rct-FPM-rhino/big_model/prepare_model.py", line 16, in prepare_model
        prepare_params(model, args.save_path, dtype='float32')
      File "miniconda3/envs/rhino/lib/python3.8/site-packages/pytorch_lit/export.py", line 19, in prepare_params
        _params_to_memmap(parameters, path.join(save_dir, "model.bin"),
      File "miniconda3/envs/rhino/lib/python3.8/site-packages/pytorch_lit/export.py", line 52, in _params_to_memmap
        param = get_param(k)
      File "miniconda3/envs/rhino/lib/python3.8/site-packages/pytorch_lit/export.py", line 50, in <lambda>
        get_param = lambda key: params"get"
    TypeError: <lambda>() missing 1 required positional argument: 'k'

    Package list:

    certifi 2021.10.8, numpy 1.21.4, pip 21.2.4, pytorch-lit 0.1.6, setuptools 58.0.4, torch 1.10.0, tqdm 4.62.3, typing_extensions 4.0.1, wheel 0.37.0

    Model: gpt-j-6B

    Any suggestions? Thanks.

    opened by changleilei 1
  • GPT-J generation speed very low

    GPT-J output is very slow: generating 200 tokens takes about 20 minutes, and generating 2048 takes more than an hour, which significantly limits any experimentation with the model.

    I checked GPU utilization during inference: it is only about 1 to 4 percent, and GPU memory usage stays below 4 GB, although my system has 8 GB of GPU memory. If the full GPU were utilized, inference might be significantly faster.

    Are there simple hacks to speed up inference?

    opened by usama-ahmedkhan 3
  • Weights file format is changed, function partial_loader fails

    Hi, thanks for your effort in making it easy to load and run inference with large models. I tried your code on a GPT-J model with a different file format: its weights are split across several .pt files rather than the single .bin file that the partial_loader() function expects. Does the code work with multiple weight files, and how can I change it?

    opened by usama-ahmedkhan 4
Releases(0.1.7)
Owner
Amin Rezaei
Computer Science BSc, Neural Networks Enthusiast