TimeSHAP explains Recurrent Neural Network predictions.

Overview

TimeSHAP

TimeSHAP is a model-agnostic, recurrent explainer that builds upon KernelSHAP and extends it to the sequential domain. TimeSHAP computes event/timestamp-, feature-, and cell-level attributions. As sequences can be arbitrarily long, TimeSHAP also implements a pruning algorithm, based on Shapley values, that finds a subset of consecutive, recent events that contribute the most to the decision.

This repository is the code implementation of the TimeSHAP algorithm presented in the paper TimeSHAP: Explaining Recurrent Models through Sequence Perturbations, published at KDD 2021.

Links to the paper here, and to the video presentation here.

Install TimeSHAP

Clone the repository into a local directory using:

git clone https://github.com/feedzai/timeshap.git

Install the package using pip:

pip install timeshap

To test your installation, start a Python session in your terminal using

python

And import TimeSHAP

import timeshap

TimeSHAP in 30 seconds

Inputs

  • Model being explained;
  • Instance(s) to explain;
  • Background instance.

Outputs

  • Local pruning output; (explaining a single instance)
  • Local event explanations; (explaining a single instance)
  • Local feature explanations; (explaining a single instance)
  • Global pruning statistics; (explaining multiple instances)
  • Global event explanations; (explaining multiple instances)
  • Global feature explanations; (explaining multiple instances)

Model Interface

For TimeSHAP to explain a model, an entry point must be provided. This Callable entry point must receive a 3-D numpy array of shape (#sequences; #sequence length; #features) and return a 2-D numpy array of shape (#sequences; 1) with the corresponding score of each sequence. Optionally, to make TimeSHAP more efficient, the entry point can also return the hidden state of the model together with the score (if applicable).

TimeSHAP is able to explain any black-box model as long as it complies with the previously described interface, including both PyTorch and TensorFlow models, both exemplified in our tutorials (PyTorch, TensorFlow).
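As a minimal sketch of the interface described above, the following toy entry point uses plain numpy instead of a trained model. The function name `score_sequences` and the scoring rule are assumptions made for illustration; only the input and output shapes come from the description.

```python
import numpy as np


def score_sequences(sequences: np.ndarray) -> np.ndarray:
    """Toy stand-in for a trained model (hypothetical, for illustration only).

    Receives an array of shape (n_sequences, sequence_length, n_features)
    and returns one score per sequence, with shape (n_sequences, 1).
    """
    last_event = sequences[:, -1, :]                 # (n_sequences, n_features)
    logits = last_event.sum(axis=1, keepdims=True)   # (n_sequences, 1)
    return 1.0 / (1.0 + np.exp(-logits))             # sigmoid keeps scores in (0, 1)


batch = np.zeros((4, 10, 3))      # 4 sequences, 10 events each, 3 features
scores = score_sequences(batch)   # shape (4, 1)
```

Any callable with these shapes, such as a lambda around a real model's predict method, plays the same role.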

Examples provided in our tutorials:

  • TensorFlow
model = tf.keras.models.Model(inputs=inputs, outputs=ff2)
f = lambda x: model.predict(x)
  • PyTorch - (Example where model receives and returns hidden states)
model_wrapped = TorchModelWrapper(model)
f_hs = lambda x, y=None: model_wrapped.predict_last_hs(x, y)

Model Wrappers

To facilitate the interface between models and TimeSHAP, TimeSHAP implements ModelWrappers. These wrappers, used in the PyTorch tutorial notebook, make the explained models more flexible, as they handle:

  • Batching logic: useful when very large inputs or NSamples cannot fit in GPU memory, so batching mechanisms are required;
  • Input format/type: useful when your model does not work with numpy arrays. This is the case in our provided PyTorch example.
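The batching idea can be sketched with a small wrapper. This is not TimeSHAP's ModelWrapper implementation: the class name `BatchedModel`, its `batch_size` parameter, and the float32 conversion are assumptions chosen for illustration.

```python
import numpy as np


class BatchedModel:
    """Hypothetical wrapper that scores a large input in fixed-size batches."""

    def __init__(self, predict_fn, batch_size=1024):
        self.predict_fn = predict_fn    # any callable returning one score per sequence
        self.batch_size = batch_size

    def __call__(self, sequences):
        # Convert to the dtype the wrapped model expects (assumed float32 here).
        sequences = np.asarray(sequences, dtype=np.float32)
        scores = []
        for start in range(0, len(sequences), self.batch_size):
            batch = sequences[start:start + self.batch_size]
            # Force every batch result into the (n_sequences, 1) layout.
            scores.append(np.asarray(self.predict_fn(batch)).reshape(-1, 1))
        return np.concatenate(scores, axis=0)


def mean_score(x):
    # Toy predictor: average of all values in each sequence.
    return x.mean(axis=(1, 2))


f = BatchedModel(mean_score, batch_size=4)
data = np.random.default_rng(0).normal(size=(10, 5, 3))
scores = f(data)   # three batches of sizes 4, 4 and 2, stacked into shape (10, 1)
```

A wrapper like this keeps the entry-point contract intact while bounding how many perturbed sequences reach the model at once.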

TimeSHAP Explanation Methods

TimeSHAP offers several methods, depending on the desired explanations. Local methods provide a detailed view of a model decision for the specific sequence being explained. Global methods aggregate local explanations over a given dataset to present a global view of the model.

Local Explanations

Pruning

local_pruning() performs the pruning algorithm on a given sequence with a user-defined tolerance and returns the pruning index along with the information for plotting.

plot_temp_coalition_pruning() plots the pruning algorithm information calculated by local_pruning().
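To give an intuition for what pruning looks for, here is a simplified sketch, not the library's implementation. It assumes the sequence is split into two coalitions (older events and recent events), that absent events are replaced by a baseline event, and that the split moves backwards until the older coalition's exact two-player Shapley value falls within the tolerance. The names `toy_model`, `older_events_shapley` and `prune_index` are hypothetical.

```python
import numpy as np


def toy_model(x):
    """Toy scorer where recent events weigh more; (n, seq_len, n_feat) -> (n, 1)."""
    weights = 0.5 ** np.arange(x.shape[1])[::-1]    # the last event has weight 1
    return (x.sum(axis=2) * weights).sum(axis=1, keepdims=True)


def older_events_shapley(f, sequence, baseline_event, split):
    """Exact two-player Shapley value of the events before index `split`."""
    seq_len = sequence.shape[1]
    background = np.tile(baseline_event, (1, seq_len, 1))
    only_old = background.copy()
    only_old[:, :split] = sequence[:, :split]
    only_recent = background.copy()
    only_recent[:, split:] = sequence[:, split:]
    f_none, f_old, f_recent, f_all = (
        float(f(a)[0, 0]) for a in (background, only_old, only_recent, sequence)
    )
    # Average the marginal contribution of the older coalition over both orderings.
    return 0.5 * ((f_old - f_none) + (f_all - f_recent))


def prune_index(f, sequence, baseline_event, tol):
    """Walk backwards and return the first split whose older part matters <= tol."""
    for split in range(sequence.shape[1] - 1, -1, -1):
        if abs(older_events_shapley(f, sequence, baseline_event, split)) <= tol:
            return split
    return 0


sequence = np.ones((1, 10, 2))
baseline_event = np.zeros((1, 1, 2))
kept_from = prune_index(toy_model, sequence, baseline_event, tol=0.1)
```

Events from index `kept_from` onwards are the ones a later explanation step would keep; everything before it can be grouped into a single pruned block.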

Event level explanations

local_event() calculates event-level explanations of a given sequence with the user-given parameters and returns them.

plot_event_heatmap() plots the event-level explanations calculated by local_event().
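For intuition about what an event-level attribution is, the sketch below computes exact Shapley values over the events of a very short sequence by enumerating every coalition. TimeSHAP builds on KernelSHAP, which approximates these values by sampling, so this brute-force version is only an illustration. The toy model, the all-zeros baseline event, and the function names are assumptions.

```python
import itertools
import math

import numpy as np


def toy_model(x):
    """Toy scorer: sigmoid of a recency-weighted sum; (n, seq_len, n_feat) -> (n, 1)."""
    weights = np.linspace(0.2, 1.0, x.shape[1])
    logits = (x.sum(axis=2) * weights).sum(axis=1, keepdims=True)
    return 1.0 / (1.0 + np.exp(-logits))


def exact_event_shapley(f, sequence, baseline_event):
    """Exact Shapley value of each event; feasible only for short sequences."""
    seq_len = sequence.shape[1]
    background = np.tile(baseline_event, (1, seq_len, 1))

    def value(coalition):
        # Events in the coalition keep their real values; the rest use the baseline.
        x = background.copy()
        for event in coalition:
            x[:, event] = sequence[:, event]
        return float(f(x)[0, 0])

    phi = np.zeros(seq_len)
    for event in range(seq_len):
        others = [e for e in range(seq_len) if e != event]
        for size in range(seq_len):
            weight = (math.factorial(size) * math.factorial(seq_len - size - 1)
                      / math.factorial(seq_len))
            for coalition in itertools.combinations(others, size):
                phi[event] += weight * (value(coalition + (event,)) - value(coalition))
    return phi


sequence = np.random.default_rng(0).normal(size=(1, 5, 2))
baseline_event = np.zeros((1, 1, 2))
phi = exact_event_shapley(toy_model, sequence, baseline_event)
```

The attributions sum to the difference between the score of the real sequence and the score of the all-baseline sequence, which is the efficiency property of Shapley values.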

Feature level explanations

local_feat() calculates feature-level explanations of a given sequence with the user-given parameters and returns them.

plot_feat_barplot() plots the feature-level explanations calculated by local_feat().

Cell level explanations

local_cell_level() calculates cell-level explanations of a given sequence from the respective event- and feature-level explanations and user-given parameters, returning the respective cell-level explanations.

plot_cell_level() plots the cell-level explanations calculated by local_cell_level().

Local Report

local_report() calculates TimeSHAP local explanations for a given sequence and plots them.

Global Explanations

Global pruning statistics

prune_all() performs the pruning algorithm on multiple given sequences.

pruning_statistics() calculates the pruning statistics for several user-given pruning tolerances using the pruning data calculated by prune_all(), returning a pandas.DataFrame with the statistics.
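The kind of summary such statistics provide can be sketched with pandas. The numbers and column names below are made up for illustration and do not reflect the actual format returned by prune_all() or pruning_statistics().

```python
import pandas as pd

# Hypothetical pruning results: for each tolerance, the number of most recent
# events kept for each of four sequences (made-up numbers, not TimeSHAP output).
kept_events = {
    0.05: [12, 8, 20, 10],
    0.075: [9, 6, 15, 8],
    0.1: [6, 5, 11, 6],
}

rows = [
    {"Tolerance": tol, "Kept events": n}
    for tol, counts in kept_events.items()
    for n in counts
]

# One row per tolerance, with the mean and standard deviation of kept events.
stats = (
    pd.DataFrame(rows)
    .groupby("Tolerance")["Kept events"]
    .agg(["mean", "std"])
)
```

A table like this makes the trade-off visible: a larger tolerance prunes more aggressively, so fewer events remain to be explained.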

Global event level explanations

event_explain_all() calculates TimeSHAP event-level explanations for multiple instances given user-defined parameters.

plot_global_event() plots the global event-level explanations calculated by event_explain_all().

Global feature level explanations

feat_explain_all() calculates TimeSHAP feature-level explanations for multiple instances given user-defined parameters.

plot_global_feat() plots the global feature-level explanations calculated by feat_explain_all().

Global report

global_report() calculates TimeSHAP explanations for multiple instances, aggregating the explanations on two plots and returning them.

Tutorial

To see TimeSHAP interfaces and methods in action, you can consult AReM.ipynb. In this tutorial we get an open-source dataset, process it, train a PyTorch recurrent model on it, and use TimeSHAP to explain it, showcasing all previously described methods.

Additionally, we train a TensorFlow model on the same dataset in AReM_TF.ipynb.

Citing TimeSHAP

@inproceedings{bento2021timeshap,
    author = {Bento, Jo\~{a}o and Saleiro, Pedro and Cruz, Andr\'{e} F. and Figueiredo, M\'{a}rio A.T. and Bizarro, Pedro},
    title = {TimeSHAP: Explaining Recurrent Models through Sequence Perturbations},
    year = {2021},
    isbn = {9781450383325},
    publisher = {Association for Computing Machinery},
    address = {New York, NY, USA},
    url = {https://doi.org/10.1145/3447548.3467166},
    doi = {10.1145/3447548.3467166},
    booktitle = {Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery & Data Mining},
    pages = {2565–2573},
    numpages = {9},
    keywords = {SHAP, Shapley values, TimeSHAP, XAI, RNN, explainability},
    location = {Virtual Event, Singapore},
    series = {KDD '21}
}

Comments
  • Error in running the example notebook (AReM_TF)

    Nice work! I have been trying to run one of the tutorial notebooks in the repository (i.e., AReM_TF), but I faced an error. The notebook chunk that produces the error is:

    from timeshap.explainer import local_report, local_pruning
    
    pruning_dict = {'tol': 0.025}
    event_dict = {'rs': 42, 'nsamples': 320}
    feature_dict = {'rs': 42, 'nsamples': 320, 'feature_names': model_features, 'plot_features': plot_feats}
    cell_dict = {'rs': 42, 'nsamples': 320, 'top_x_feats': 2, 'top_x_events': 2}
    local_report(f, pos_x_data, pruning_dict, event_dict, feature_dict,cell_dict=cell_dict, entity_uuid=positive_sequence_id, entity_col='all_id', baseline=average_event)
    

    and the produced error is as follows: [two screenshots of the error]

    It would be great if you could help me with this error.

    opened by aminnayebi 7
  • Unable to install timeshap package

    Hello @feedzaiadmin , @saleiro ,

    I am unable to run the command "pip install timeshap" on my Ubuntu system. It throws the error below:

    ERROR: Could not find a version that satisfies the requirement timeshap (from versions: none)
    ERROR: No matching distribution found for timeshap

    This looks like a compatibility issue to me. I have tried Python versions 3.6 and 3.8. Is there a specific version that supports it? I will be waiting for a response.

    Thanks

    opened by vishants98 5
  • How to adapt the transformation function to account for variable sequence length?

    I am trying to use TimeSHAP on my use case. Per my understanding, in AReM example, the way you transform the data using the df_to_numpy function is to make a prediction for the last value of the sequence – see the screen below:

    [image] In the AReM tutorial data, the predictions are based on the whole sequence: all rows (Row IDs 1-10) are used for sequence ID 1 (light blue) and the prediction is made for Timestamp 10 (dark blue; Row ID 10). Next, the light orange rows (Row IDs 11-20) are used to predict the label marked in dark orange (Row ID 20).

    In my use case, the model predicts on a rolling-window basis and I need predictions for every row (not only for a sequence). See the screenshot and explanation below. [image] Let's say my rolling window is 6: Row IDs 1-6 (light green) are used to predict Row ID 7 (dark green), then Row IDs 2-7 (light grey) are used to predict Row ID 8 (dark grey), etc. When a new sequence starts, we repeat the process, so we take Row IDs 11-16 and predict Row ID 17, etc. For my use case, it's important to evaluate the predictions for every Row ID, not only for the whole sequence.

    The problem I am facing is that when I run the function get_avg_score_with_avg_event on data defined as in the picture above, I get the following error: [image]

    The way my data is transformed from 2D into 3D format is defined by the function below: [image]

    My question is whether it’s possible to make TimeSHAP work for data transformed in the way described in my use case. When I use the transformation defined in your function df_to_numpy, I do not get an error; however, it is not adapted to my use case.

    opened by grzechowiak 4
  • Issue in reproducing TimeSHAP Tutorial - TensorFlow - AReM dataset

    Foremost, thanks for the library, really great job!

    I am trying to reproduce your TimeSHAP Tutorial for TF and I am having an issue in the section for Global Explanations when running the global_report() function (screenshot below):

    [image]

    The error refers to encoding the \u2264 character, which is the sign ≤ (less than or equal to). I tried to solve it myself by modifying pruning.py according to the error, adding encoding="utf-8" to line 326, with open(file_path, 'a', newline='') as file:, but it didn't solve the problem. Any advice is very welcome!

    Also, for completeness, I want to mention that I had a problem loading the data (error screenshots below). I was able to solve it only by deleting 2 datasets, cycling/dataset9.csv and cycling/dataset14.csv, after which the rest of the code worked.

    1/2 [image] 2/2 [image]

    opened by grzechowiak 4
  • Timeshap for regression

    I am working on a time series forecasting problem using LSTM. Can I use timeshap for such a regression problem? Do you by chance have a demo for regression?

    Thanks

    opened by mgorjis 3
  • Plot Coalition Pruning is not working

    Hi, I am doing a project as part of a Master's thesis. I was testing your introductory notebook with the TensorFlow implementation, but it fails because of an error raised by one of your plot functions. The funny thing is that your other notebook, i.e., the one with the torch model, works fine. I managed to find the issue: it is raised by an inner method called solve_negatives_method in src/timeshap/plot/pruning.py, which raises an InvalidIndexError caused by this specific row:

        df.at[corresponding_row.index, 'Shapley Value'] = corresponding_row['Shapley Value'].values[0] + row['Shapley Value']
    

    As far as I understood, we are passing a list-like index, usually holding only one value, to the pandas.DataFrame.at accessor, which requires a scalar row label and a column identifier. As such, one solution that works for now could be:

        df.at[corresponding_row.index[0], 'Shapley Value'] = corresponding_row['Shapley Value'].values[0] + row['Shapley Value']
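
The difference can be reproduced in isolation. The sketch below uses a made-up two-row frame rather than TimeSHAP's plotting data, and shows only the scalar-label form; the column names are borrowed from the snippet above.

```python
import pandas as pd

df = pd.DataFrame({"Coalition": ["a", "b"], "Shapley Value": [0.25, -0.5]})

# Boolean filtering returns a DataFrame, so `.index` is an Index object even
# when exactly one row matches.
matching = df[df["Coalition"] == "a"]

# `.at` is a scalar accessor: take the single label out of the Index first.
label = matching.index[0]
df.at[label, "Shapley Value"] = matching["Shapley Value"].values[0] + 0.5
```

Passing `matching.index` itself to `.at` hands it a list-like key, which is what the InvalidIndexError in recent pandas versions complains about.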
    

    I want to thank you very much for the work done, and I will be happy to discuss further results of my thesis with you. :) Hoping for the best and looking forward to hearing from you, Eric

    opened by Erhtric 2
  • TimeSHAP for text?

    I'm working with a 1-layer GRU for text classification that takes BERT embeddings at the input. Each input sequence is of the shape (sequence length, bert-embedding-dimension). I'm looking for word level attribution scores for each sequence's prediction. Currently with the captum integrated gradients and occlusion explainers, I get attribution scores that are almost always the last few words of each sequence. This seems like it's stemming from the directional processing of GRU - any thoughts?

    Do you think TimeSHAP would be applicable for my use case? I suppose I could consider each word as an event and each embedding dimension as a feature, then I could use the event level local explanations from the library? However, note that in my case, the events (i.e words) from the beginning of the sequence could be more important than those at the end of the sequence (i.e. most recent ones) - this violates the assumptions you use for your approximation (i.e pruning), so perhaps it's not applicable to text?

    opened by itsmemala 2
  • Intuition around pruning/baseline selection

    While calculating a global report, I'm currently running into errors that "Score difference between baseline and instance is too low < 0.1...Consider choosing another baseline." My baselines have been the average_event and average_sequence. I also notice that occasionally the values in the error change with different pruning tolerance, but not in a consistent way. Do you have advice for dealing with this? Thanks.

    opened by xydisla 2
  • CNN model

    I currently have a CNN model, and previously had to do some strange hacking to get time series importance values. Your package now shines a better light on this issue. For multivariate forecasts CNNs still fare well, sometimes better than RNNs. From what I can see, there is no reason why CNNs wouldn't work with your software. Let me know if I have that wrong.

    opened by firmai 2
  • How to speed up TIMESHAP computation

    Hi all!

    The package itself is really interesting and intuitive to use. But I want to speed up the TimeSHAP computation. Can I use a GPU to calculate Shapley values? I used the device parameter in TorchModelWrapper, but the efficiency of the GPU is too low to accelerate the computation. Any suggestions would be appreciated.

    opened by Changshu135 2
  • Error when executing local_report on TF example: InvalidIndexError: Int64Index([1], dtype='int64')

    Hi all! When executing your TF example notebook without any changes, it fails at this line: local_report(f, pos_x_data, pruning_dict, event_dict, feature_dict, cell_dict=cell_dict, entity_uuid=positive_sequence_id, entity_col='all_id', baseline=average_event)

    with the following error: InvalidIndexError: Int64Index([1], dtype='int64')

    Stack trace:

    TypeError                                 Traceback (most recent call last)
    File ~/opt/anaconda3/envs/timeshap/lib/python3.10/site-packages/pandas/core/indexes/base.py:3621, in Index.get_loc(self, key, method, tolerance)
       3620 try:
    -> 3621     return self._engine.get_loc(casted_key)
       3622 except KeyError as err:
    
    File pandas/_libs/index.pyx:136, in pandas._libs.index.IndexEngine.get_loc()
    
    File pandas/_libs/index.pyx:142, in pandas._libs.index.IndexEngine.get_loc()
    
    TypeError: 'Int64Index([1], dtype='int64')' is an invalid key
    
    During handling of the above exception, another exception occurred:
    
    InvalidIndexError                         Traceback (most recent call last)
    Input In [27], in <cell line: 7>()
          5 feature_dict = {'rs': 42, 'nsamples': 32000, 'feature_names': model_features, 'plot_features': plot_feats}
          6 cell_dict = {'rs': 42, 'nsamples': 32000, 'top_x_feats': 2, 'top_x_events': 2}
    ----> 7 local_report(f, pos_x_data, pruning_dict, event_dict, feature_dict, cell_dict=cell_dict, entity_uuid=positive_sequence_id, entity_col='all_id', baseline=average_event)
    
    File ~/temp/timeshap/src/timeshap/explainer/local_methods.py:139, in local_report(f, data, pruning_dict, event_dict, feature_dict, cell_dict, entity_uuid, entity_col, time_col, model_features, baseline, verbose)
        137 pruning_idx = data.shape[1] + coal_prun_idx
        138 plot_lim = max(abs(coal_prun_idx)+10, 40)
    --> 139 pruning_plot = plot_temp_coalition_pruning(coal_plot_data, coal_prun_idx, plot_lim)
        141 event_data = local_event(f, data, event_dict, entity_uuid, entity_col, baseline, pruning_idx)
        142 event_plot = plot_event_heatmap(event_data)
    
    File ~/temp/timeshap/src/timeshap/plot/pruning.py:53, in plot_temp_coalition_pruning(df, pruned_idx, plot_limit, solve_negatives)
         51 df = df[df['t (event index)'] >= -plot_limit]
         52 if solve_negatives:
    ---> 53     df = solve_negatives_method(df)
         55 base = (alt.Chart(df).encode(
         56     x=alt.X("t (event index)", axis=alt.Axis(title='t (event index)', labelFontSize=15,
         57                           titleFontSize=15)),
       (...)
         70 )
         71 )
         73 area_chart = base.mark_area(opacity=0.5)
    
    File ~/temp/timeshap/src/timeshap/plot/pruning.py:47, in plot_temp_coalition_pruning.<locals>.solve_negatives_method(df)
         45 for idx, row in negative_values.iterrows():
         46     corresponding_row = df[np.logical_and(df['t (event index)'] == row['t (event index)'], ~(df['Coalition'] == row['Coalition']))]
    ---> 47     df.at[corresponding_row.index, 'Shapley Value'] = corresponding_row['Shapley Value'].values[0] + row['Shapley Value']
         48     df.at[idx, 'Shapley Value'] = 0
         49 return df
    
    File ~/opt/anaconda3/envs/timeshap/lib/python3.10/site-packages/pandas/core/indexing.py:2281, in _AtIndexer.__setitem__(self, key, value)
       2278     self.obj.loc[key] = value
       2279     return
    -> 2281 return super().__setitem__(key, value)
    
    File ~/opt/anaconda3/envs/timeshap/lib/python3.10/site-packages/pandas/core/indexing.py:2236, in _ScalarAccessIndexer.__setitem__(self, key, value)
       2233 if len(key) != self.ndim:
       2234     raise ValueError("Not enough indexers for scalar access (setting)!")
    -> 2236 self.obj._set_value(*key, value=value, takeable=self._takeable)
    
    File ~/opt/anaconda3/envs/timeshap/lib/python3.10/site-packages/pandas/core/frame.py:3869, in DataFrame._set_value(self, index, col, value, takeable)
       3867 else:
       3868     series = self._get_item_cache(col)
    -> 3869     loc = self.index.get_loc(index)
       3871 # setitem_inplace will do validation that may raise TypeError
       3872 #  or ValueError
       3873 series._mgr.setitem_inplace(loc, value)
    
    File ~/opt/anaconda3/envs/timeshap/lib/python3.10/site-packages/pandas/core/indexes/base.py:3628, in Index.get_loc(self, key, method, tolerance)
       3623         raise KeyError(key) from err
       3624     except TypeError:
       3625         # If we have a listlike key, _check_indexing_error will raise
       3626         #  InvalidIndexError. Otherwise we fall through and re-raise
       3627         #  the TypeError.
    -> 3628         self._check_indexing_error(key)
       3629         raise
       3631 # GH#42269
    
    File ~/opt/anaconda3/envs/timeshap/lib/python3.10/site-packages/pandas/core/indexes/base.py:5637, in Index._check_indexing_error(self, key)
       5633 def _check_indexing_error(self, key):
       5634     if not is_scalar(key):
       5635         # if key is not a scalar, directly raise an error (the code below
       5636         # would convert to numpy arrays and raise later any way) - GH29926
    -> 5637         raise InvalidIndexError(key)
    
    InvalidIndexError: Int64Index([1], dtype='int64')
    
    opened by JulianKlug 2