Convert scikit-learn models to PyTorch modules

sk2torch

sk2torch converts scikit-learn models into PyTorch modules that can be tuned with backpropagation and even compiled as TorchScript.

Problems solved by this project:

  1. scikit-learn cannot perform inference on a GPU. Models like SVMs have a lot to gain from fast GPU primitives, and converting the models to PyTorch gives immediate access to these primitives.
  2. While scikit-learn supports serialization through pickle, pickled models are not guaranteed to load or behave identically across versions of the library. TorchScript, on the other hand, provides a convenient, safe way to save a model together with its implementation. The resulting models can be loaded anywhere that PyTorch is installed, even without importing sk2torch.
  3. While certain models like SVMs and linear classifiers are theoretically end-to-end differentiable, scikit-learn provides no mechanism to compute gradients through trained models. PyTorch provides this functionality mostly for free.
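
Point 3 can be illustrated without sk2torch itself. The following is a minimal sketch, assuming scikit-learn and PyTorch are installed: it copies the learned coef_ and intercept_ of a fitted SGDClassifier into tensors by hand and lets autograd differentiate the decision function with respect to the input. The toy dataset and all variable names are illustrative assumptions, and none of sk2torch's own wrappers are used here.

```python
import numpy as np
import torch
from sklearn.linear_model import SGDClassifier

# Toy, roughly linearly separable data (an assumption for illustration).
rng = np.random.RandomState(0)
x = rng.normal(size=(100, 3))
y = (x[:, 0] + 0.5 * x[:, 1] > 0).astype(int)

clf = SGDClassifier(random_state=0).fit(x, y)

# Copy the fitted parameters into tensors by hand.
weight = torch.from_numpy(clf.coef_)       # shape (1, 3), float64
bias = torch.from_numpy(clf.intercept_)    # shape (1,), float64

# Re-implement decision_function and differentiate it w.r.t. the input.
inputs = torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float64, requires_grad=True)
scores = inputs @ weight.T + bias          # shape (1, 1)
scores.sum().backward()

# For a linear model, the input gradient is just the weight vector.
input_grad = inputs.grad
```

For a linear classifier the result is trivial, but the same mechanism applies to non-linear models such as kernel SVMs once their inference code is expressed in PyTorch.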

See Usage for a high-level example of using the library. See How it works to see which modules are supported.

For fun, here's a vector field produced by differentiating the probability predictions of a two-class SVM (produced by examples/svm_vector_field.py):

[Figure: a vector field quiver plot with two modes]

Usage

First, train a model with scikit-learn as usual:

from sklearn.linear_model import SGDClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

x, y = create_some_dataset()
model = Pipeline([
    ("center", StandardScaler(with_std=False)),
    ("classify", SGDClassifier()),
])
model.fit(x, y)
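
The snippet above leaves create_some_dataset undefined. One hypothetical stand-in, assuming scikit-learn's make_classification is acceptable for experimentation, is sketched below. It produces three features so that the three-element input used in the predict call later in this section has a matching shape. The function body, its parameters, and the default sizes are assumptions, not part of sk2torch.

```python
from sklearn.datasets import make_classification


def create_some_dataset(n_samples=200, seed=0):
    """Hypothetical stand-in: a synthetic binary dataset with 3 features."""
    x, y = make_classification(
        n_samples=n_samples,
        n_features=3,
        n_informative=3,
        n_redundant=0,
        random_state=seed,
    )
    return x, y


x, y = create_some_dataset()
```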

Then call sk2torch.wrap on the model to create a PyTorch equivalent:

import sk2torch
import torch

torch_model = sk2torch.wrap(model)
print(torch_model.predict(torch.tensor([[1., 2., 3.]]).double()))

You can save a model with TorchScript:

import torch.jit

torch.jit.script(torch_model).save("path.pt")

# ... sk2torch need not be installed to load the model.
loaded_model = torch.jit.load("path.pt")
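
To show why a scripted model stays usable without the code that defined it, here is a minimal sketch using a hand-written module instead of an sk2torch wrapper. TinyLinearClassifier, its weights, and the file name are hypothetical. The sketch assumes that methods other than forward are marked with torch.jit.export so that they are compiled and remain callable on the loaded model; whether sk2torch uses this exact mechanism internally is not confirmed here.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyLinearClassifier(nn.Module):
    """Hypothetical hand-written module; not part of sk2torch."""

    def __init__(self, weight: torch.Tensor, bias: torch.Tensor):
        super().__init__()
        self.register_buffer("weight", weight)
        self.register_buffer("bias", bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.predict(x)

    @torch.jit.export
    def decision_function(self, x: torch.Tensor) -> torch.Tensor:
        # Linear score for each row of x.
        return x @ self.weight + self.bias

    @torch.jit.export
    def predict(self, x: torch.Tensor) -> torch.Tensor:
        # Class 1 where the score is positive, class 0 otherwise.
        return (self.decision_function(x) > 0).long()


module = TinyLinearClassifier(
    torch.tensor([1.0, -1.0, 0.5], dtype=torch.float64),
    torch.tensor(0.25, dtype=torch.float64),
)

# Script, save, and reload the module through a temporary file.
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "tiny.pt")
    torch.jit.script(module).save(path)
    loaded = torch.jit.load(path)

batch = torch.tensor([[1.0, 2.0, 3.0], [0.0, 5.0, 0.0]], dtype=torch.float64)
original_pred = module.predict(batch)
loaded_pred = loaded.predict(batch)
```

The loaded object carries its own compiled predict and decision_function, so the Python class definition is no longer needed at load time.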

For a full example of training a model and using its PyTorch translation, see examples/svm_vector_field.py.

How it works

sk2torch contains PyTorch re-implementations of supported scikit-learn models. For a supported estimator X, a class TorchX in sk2torch will be able to read the attributes of X and convert them to torch.Tensor or simple Python types. TorchX subclasses torch.nn.Module and has a method for each inference API of X (e.g. predict, decision_function, etc.).
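
The pattern described above can be sketched as follows. This is not sk2torch's actual code: TorchCenterScale and its from_sklearn constructor are hypothetical names, and the sketch assumes a StandardScaler fitted with its default settings so that both mean_ and scale_ are available. It reads the fitted attributes, stores them as tensors, and exposes a transform method that mirrors the scikit-learn one.

```python
import numpy as np
import torch
import torch.nn as nn
from sklearn.preprocessing import StandardScaler


class TorchCenterScale(nn.Module):
    """Hypothetical re-implementation of StandardScaler.transform."""

    def __init__(self, mean: torch.Tensor, scale: torch.Tensor):
        super().__init__()
        self.register_buffer("mean", mean)
        self.register_buffer("scale", scale)

    @classmethod
    def from_sklearn(cls, estimator: StandardScaler) -> "TorchCenterScale":
        # Assumes with_mean=True and with_std=True, so both attributes exist.
        mean = torch.from_numpy(np.asarray(estimator.mean_, dtype=np.float64))
        scale = torch.from_numpy(np.asarray(estimator.scale_, dtype=np.float64))
        return cls(mean, scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.transform(x)

    def transform(self, x: torch.Tensor) -> torch.Tensor:
        # Same arithmetic as StandardScaler.transform for dense input.
        return (x - self.mean) / self.scale


rng = np.random.RandomState(0)
data = rng.normal(loc=2.0, scale=3.0, size=(50, 4))
scaler = StandardScaler().fit(data)

torch_scaler = TorchCenterScale.from_sklearn(scaler)
torch_out = torch_scaler.transform(torch.from_numpy(data))
sklearn_out = scaler.transform(data)
```

Comparing torch_out against sklearn_out is a simple way to check that such a port matches the original estimator.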

Which modules are supported? The easiest way to get an up-to-date list is via the supported_classes() function, which returns all wrap()able scikit-learn classes:

>>> import sk2torch
>>> sk2torch.supported_classes()
[<class 'sklearn.tree._classes.DecisionTreeClassifier'>, <class 'sklearn.tree._classes.DecisionTreeRegressor'>, <class 'sklearn.dummy.DummyClassifier'>, <class 'sklearn.ensemble._gb.GradientBoostingClassifier'>, <class 'sklearn.preprocessing._label.LabelBinarizer'>, <class 'sklearn.svm._classes.LinearSVC'>, <class 'sklearn.svm._classes.LinearSVR'>, <class 'sklearn.neural_network._multilayer_perceptron.MLPClassifier'>, <class 'sklearn.kernel_approximation.Nystroem'>, <class 'sklearn.pipeline.Pipeline'>, <class 'sklearn.linear_model._stochastic_gradient.SGDClassifier'>, <class 'sklearn.preprocessing._data.StandardScaler'>, <class 'sklearn.svm._classes.SVC'>, <class 'sklearn.svm._classes.NuSVC'>, <class 'sklearn.svm._classes.SVR'>, <class 'sklearn.svm._classes.NuSVR'>, <class 'sklearn.compose._target.TransformedTargetRegressor'>]

Comparison to sklearn-onnx

sklearn-onnx is an open source package for converting trained scikit-learn models into ONNX. Like sk2torch, sklearn-onnx re-implements inference functions for various models, meaning that it can also provide serialization and GPU acceleration for supported modules.

Naturally, neither library will support modules that aren't manually ported. As a result, the two libraries support different subsets of all available models and methods. For example, sk2torch supports the SVC probability prediction methods predict_proba and predict_log_proba, whereas sklearn-onnx does not.

While sklearn-onnx exports models to ONNX, sk2torch exports models to Python objects with familiar method names that can be fine-tuned, backpropagated through, and serialized in a user-friendly way. PyTorch is strictly more general than ONNX, since PyTorch models can be converted to ONNX if desired.

Owner
Alex Nichol