Fastshap: A fast, approximate shap kernel

Related tags

Deep Learningfastshap
Overview

fastshap: A fast, approximate shap kernel

fastshap was designed to be:

  • Fast Calculating shap values can take an extremely long time. fastshap utilizes inner and outer batch assignments to keep the calculations inside vectorized operations as often as it can.
  • Used on Tabular Data Can accept numpy arrays or pandas DataFrames, and can handle categorical variables natively. As of right now, only 1 dimensional outputs are accepted.

WARNING This package specifically offers a kernel explainer, which can calculate approximate shap values of f(X) towards y for any function f. Much faster shap solutions are available specifically for gradient boosted trees.

Installation

This package can be installed using either pip or conda, through conda-forge:

# Using pip
$ pip install fastshap --no-cache-dir

You can also download the latest development version from this repository. If you want to install from github with conda, you must first run conda install pip git.

$ pip install git+https://github.com/AnotherSamWilson/fastshap.git

Basic Usage

We will use the iris dataset for this example. Here, we load the data and train a simple lightgbm model on the dataset:

from sklearn.datasets import load_iris
import pandas as pd
import lightgbm as lgb
import numpy as np

# Define our dataset and target variable
data = pd.concat(load_iris(as_frame=True,return_X_y=True),axis=1)
data.rename({"target": "species"}, inplace=True, axis=1)
data["species"] = data["species"].astype("category")
target = data.pop("sepal length (cm)")

# Train our model
dtrain = lgb.Dataset(data=data, label=target)
lgbmodel = lgb.train(
    params={"seed": 1, "verbose": -1},
    train_set=dtrain,
    num_boost_round=10
)

# Define the function we wish to build shap values for.
model = lgbmodel.predict

preds = model(data)

We now have a model which takes a Pandas dataframe, and returns predictions. We can create an explainer that will use data as a background dataset to calculate the shap values of any dataset we wish:

import fastshap

ke = fastshap.KernelExplainer(model, data)
sv = ke.calculate_shap_values(data, verbose=False)

print(all(preds == sv.sum(1)))
## True

Stratifying the Background Set

We can select a subset of our data to act as a background set. By stratifying the background set on the results of the model output, we will usually get very similar results, while decreasing the caculation time drastically.

ke.stratify_background_set(5)
sv2 = ke.calculate_shap_values(
  data, 
  background_fold_to_use=0,
  verbose=False
)

print(np.abs(sv2 - sv).mean(0))
## [1.74764532e-03 1.61829094e-02 1.99534408e-03 4.02640884e-16
##  1.71084747e-02]

What we did is break up our background set into 10 different sets, stratified by the model output. We then used the first of these sets as our background set. We then compared the average difference between these shap values, and the shap values we obtained from using the entire dataset.

Choosing Batch Sizes

If the entire process was vectorized, it would require an array of size (# Samples * # Coalitions * # Background samples, # Columns). Where # Coalitions is the sum of the total number of coalitions that are going to be run. Even for small datasets, this becomes enormous. fastshap breaks this array up into chunks by splitting the process into a series of batches.

This is a list of the large arrays and their maximum size:

  • Global
    • Mask Matrix (# Coalitions, # Columns) dtype = int8
  • Outer Batch
    • Linear Targets (Total Coalition Combinations, Outer Batch Size) dtype = adaptive
  • Inner Batch
    • Model Evaluation Features (Inner Batch Size, # background samples) dtype = adaptive

The adaptive datatypes of the arrays above will be matched to the data types of the model output. Therefore, if your model returns float32, these arrays will be stored as float32. The final, returned shap values will also be returned as the datatype returned by the model.

These theoretical sizes can be calculated directly so that the user can determine appropriate batch sizes for their machine:

# Combines our background data back into 1 DataFrame
ke.stratify_background_set(1)
(
    mask_matrix_size, 
    linear_target_size, 
    inner_model_eval_set_size
) = ke.get_theoretical_array_expansion_sizes(
    outer_batch_size=150,
    inner_batch_size=150,
    n_coalition_sizes=3,
    background_fold_to_use=None,
)

print(
  np.product(linear_target_size) + np.product(inner_model_eval_set_size)
)
## 92100

For the iris dataset, even if we sent the entire set (150 rows) through as one batch, we only need 92100 elements stored in arrays. This is manageable on most machines. However, this number grows extremely quickly with the samples and number of columns. It is highly advised to determine a good batch scheme before running this process.

Specifying a Custom Linear Model

Any linear model available from sklearn.linear_model can be used to calculate the shap values. If you wish for some sparsity in the shap values, you can use Lasso regression:

from sklearn.linear_model import Lasso

# Use our entire background set
ke.stratify_background_set(1)
sv_lasso = ke.calculate_shap_values(
  data, 
  background_fold_to_use=0,
  linear_model=Lasso(alpha=0.1),
  verbose=False
)

print(sv_lasso[0,:])
## [-0.         -0.33797832 -0.         -0.14634971  5.84333333]

The default model used is sklearn.linear_model.LinearRegression.

Owner
Samuel Wilson
Samuel Wilson
Semi-supervised Video Deraining with Dynamical Rain Generator (CVPR, 2021, Pytorch)

S2VD Semi-supervised Video Deraining with Dynamical Rain Generator (CVPR, 2021) Requirements and Dependencies Ubuntu 16.04, cuda 10.0 Python 3.6.10, P

Zongsheng Yue 53 Nov 23, 2022
A package related to building quasi-fibration symmetries

qf A package related to building quasi-fibration symmetries. If you'd like to learn more about how it works, see the brief explanation and References

Paolo Boldi 1 Dec 01, 2021
Weakly Supervised Learning of Instance Segmentation with Inter-pixel Relations, CVPR 2019 (Oral)

Weakly Supervised Learning of Instance Segmentation with Inter-pixel Relations The code of: Weakly Supervised Learning of Instance Segmentation with I

Jiwoon Ahn 472 Dec 29, 2022
An Unpaired Sketch-to-Photo Translation Model

Unpaired-Sketch-to-Photo-Translation We have released our code at https://github.com/rt219/Unsupervised-Sketch-to-Photo-Synthesis This project is the

38 Oct 28, 2022
[ICSE2020] MemLock: Memory Usage Guided Fuzzing

MemLock: Memory Usage Guided Fuzzing This repository provides the tool and the evaluation subjects for the paper "MemLock: Memory Usage Guided Fuzzing

Cheng Wen 54 Jan 07, 2023
A Tensorflow implementation of the Text Conditioned Auxiliary Classifier Generative Adversarial Network for Generating Images from text descriptions

A Tensorflow implementation of the Text Conditioned Auxiliary Classifier Generative Adversarial Network for Generating Images from text descriptions

Ayushman Dash 93 Aug 04, 2022
The aim of this project is to build an AI bot that can play the Wordle game, or more generally Squabble

Wordle RL The aim of this project is to build an AI bot that can play the Wordle game, or more generally Squabble I know there are more deterministic

Aditya Arora 3 Feb 22, 2022
Code for EMNLP'21 paper "Types of Out-of-Distribution Texts and How to Detect Them"

ood-text-emnlp Code for EMNLP'21 paper "Types of Out-of-Distribution Texts and How to Detect Them" Files fine_tune.py is used to finetune the GPT-2 mo

Udit Arora 19 Oct 28, 2022
Research code of ICCV 2021 paper "Mesh Graphormer"

MeshGraphormer ✨ ✨ This is our research code of Mesh Graphormer. Mesh Graphormer is a new transformer-based method for human pose and mesh reconsructi

Microsoft 251 Jan 08, 2023
Demo code for paper "Learning optical flow from still images", CVPR 2021.

Depthstillation Demo code for "Learning optical flow from still images", CVPR 2021. [Project page] - [Paper] - [Supplementary] This code is provided t

130 Dec 25, 2022
Model parallel transformers in Jax and Haiku

Mesh Transformer Jax A haiku library using the new(ly documented) xmap operator in Jax for model parallelism of transformers. See enwik8_example.py fo

Ben Wang 4.8k Jan 01, 2023
[CoRL 2021] A robotics benchmark for cross-embodiment imitation.

x-magical x-magical is a benchmark extension of MAGICAL specifically geared towards cross-embodiment imitation. The tasks still provide the Demo/Test

Kevin Zakka 36 Nov 26, 2022
(3DV 2021 Oral) Filtering by Cluster Consistency for Large-Scale Multi-Image Matching

Scalable Cluster-Consistency Statistics for Robust Multi-Object Matching (3DV 2021 Oral Presentation) Filtering by Cluster Consistency (FCC) is a very

Yunpeng Shi 11 Sep 28, 2022
Callable PyTrees and filtered JIT/grad transformations => neural networks in JAX.

Equinox Callable PyTrees and filtered JIT/grad transformations = neural networks in JAX Equinox brings more power to your model building in JAX. Repr

Patrick Kidger 909 Dec 30, 2022
From Fidelity to Perceptual Quality: A Semi-Supervised Approach for Low-Light Image Enhancement (CVPR'2020)

Under-exposure introduces a series of visual degradation, i.e. decreased visibility, intensive noise, and biased color, etc. To address these problems, we propose a novel semi-supervised learning app

Yang Wenhan 117 Jan 03, 2023
PyTorch implementation of our ICCV 2019 paper: Liquid Warping GAN: A Unified Framework for Human Motion Imitation, Appearance Transfer and Novel View Synthesis

Impersonator PyTorch implementation of our ICCV 2019 paper: Liquid Warping GAN: A Unified Framework for Human Motion Imitation, Appearance Transfer an

SVIP Lab 1.7k Jan 06, 2023
StyleSpace Analysis: Disentangled Controls for StyleGAN Image Generation

StyleSpace Analysis: Disentangled Controls for StyleGAN Image Generation Demo video: CVPR 2021 Oral: Single Channel Manipulation: Localized or attribu

Zongze Wu 267 Dec 30, 2022
Discovering Dynamic Salient Regions with Spatio-Temporal Graph Neural Networks

Discovering Dynamic Salient Regions with Spatio-Temporal Graph Neural Networks This is the official code for DyReg model inroduced in Discovering Dyna

Bitdefender Machine Learning 11 Nov 08, 2022
[CVPR 2022] PoseTriplet: Co-evolving 3D Human Pose Estimation, Imitation, and Hallucination under Self-supervision (Oral)

PoseTriplet: Co-evolving 3D Human Pose Estimation, Imitation, and Hallucination under Self-supervision Kehong Gong*, Bingbing Li*, Jianfeng Zhang*, Ta

256 Dec 28, 2022
KE-Dialogue: Injecting knowledge graph into a fully end-to-end dialogue system.

Learning Knowledge Bases with Parameters for Task-Oriented Dialogue Systems This is the implementation of the paper: Learning Knowledge Bases with Par

CAiRE 42 Nov 10, 2022