TimeSHAP explains Recurrent Neural Network predictions.

Overview

TimeSHAP

TimeSHAP is a model-agnostic, recurrent explainer that builds upon KernelSHAP and extends it to the sequential domain. TimeSHAP computes event/timestamp-, feature-, and cell-level attributions. As sequences can be arbitrarily long, TimeSHAP also implements a pruning algorithm based on Shapley values that finds a subset of consecutive, recent events that contribute the most to the decision.

This repository is the code implementation of the TimeSHAP algorithm presented in the paper TimeSHAP: Explaining Recurrent Models through Sequence Perturbations, published at KDD 2021.

Links to the paper here, and to the video presentation here.

Install TimeSHAP

Clone the repository into a local directory using:

git clone https://github.com/feedzai/timeshap.git

Install the package using pip:

pip install timeshap

To test your installation, start a Python session in your terminal using:

python

and import TimeSHAP:

import timeshap

TimeSHAP in 30 seconds

Inputs

  • Model being explained;
  • Instance(s) to explain;
  • Background instance.
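The background instance is what absent events or features are replaced with. The tutorials pass a baseline named average_event; a minimal sketch of one way to build such a baseline, assuming it is simply the per-feature mean over every event in the training data (the helper name is hypothetical, TimeSHAP may provide its own helper, and a median could serve as well):

```python
import numpy as np

def average_event(train_sequences: np.ndarray) -> np.ndarray:
    """Hypothetical helper: mean feature vector over every event of every training
    sequence. Input shape (n_sequences, seq_len, n_features); output (1, n_features)."""
    return train_sequences.mean(axis=(0, 1))[None, :]

# 2 sequences, 5 events each, 2 features; all values 1.0 in the first and 3.0 in the second
train = np.stack([np.full((5, 2), 1.0), np.full((5, 2), 3.0)])
print(average_event(train))  # [[2. 2.]]
```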

Outputs

  • Local pruning output (explaining a single instance);
  • Local event explanations (explaining a single instance);
  • Local feature explanations (explaining a single instance);
  • Global pruning statistics (explaining multiple instances);
  • Global event explanations (explaining multiple instances);
  • Global feature explanations (explaining multiple instances).

Model Interface

In order for TimeSHAP to explain a model, an entry point must be provided. This Callable entry point must receive a 3-D numpy array of shape (#sequences; #sequence length; #features) and return a 2-D numpy array of shape (#sequences; 1) with the score of each sequence. Optionally, to make TimeSHAP more efficient, the entry point can also return the model's hidden state together with the score (where applicable).
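A minimal sketch of a Callable that satisfies this interface. The toy scoring rule (a sigmoid of the mean of the first feature over time) is made up purely to show the expected shapes; in practice f would wrap your trained model, as in the tutorial examples.

```python
import numpy as np

def f(x: np.ndarray) -> np.ndarray:
    """Hypothetical toy model: map a (n_sequences, seq_len, n_features) array
    to (n_sequences, 1) scores in (0, 1)."""
    assert x.ndim == 3, "expected (#sequences, #sequence length, #features)"
    logits = x[:, :, 0].mean(axis=1, keepdims=True)  # shape (n_sequences, 1)
    return 1.0 / (1.0 + np.exp(-logits))             # sigmoid

batch = np.zeros((4, 10, 3))  # 4 sequences, 10 events each, 3 features
scores = f(batch)
print(scores.shape)  # (4, 1)
```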

TimeSHAP can explain any black-box model that complies with the interface described above, including both PyTorch and TensorFlow models, as exemplified in our tutorials (PyTorch, TensorFlow).

Examples from our tutorials:

  • TensorFlow
model = tf.keras.models.Model(inputs=inputs, outputs=ff2)
f = lambda x: model.predict(x)
  • PyTorch (example where the model receives and returns hidden states)
model_wrapped = TorchModelWrapper(model)
f_hs = lambda x, y=None: model_wrapped.predict_last_hs(x, y)
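The PyTorch example above passes a hidden state in and out of the model. A self-contained illustration of why that can help, using a hypothetical toy RNN written in NumPy: the hidden state computed for a prefix of events can be reused, so only the remaining events need to be run again. The (x, hs) -> (scores, hs) convention mirrors the f_hs lambda above, but the exact contract TimeSHAP expects should be checked against the PyTorch tutorial.

```python
import numpy as np

rng = np.random.default_rng(0)
N_FEATURES, N_HIDDEN = 3, 5
W_x = rng.normal(size=(N_FEATURES, N_HIDDEN))
W_h = rng.normal(size=(N_HIDDEN, N_HIDDEN)) * 0.1
w_out = rng.normal(size=(N_HIDDEN, 1))

def f_hs(x, hs=None):
    """Hypothetical toy Elman RNN. Returns (scores of shape (n, 1),
    final hidden state of shape (n, N_HIDDEN))."""
    h = np.zeros((x.shape[0], N_HIDDEN)) if hs is None else hs
    for t in range(x.shape[1]):
        h = np.tanh(x[:, t, :] @ W_x + h @ W_h)
    scores = 1.0 / (1.0 + np.exp(-(h @ w_out)))
    return scores, h

seq = rng.normal(size=(2, 8, N_FEATURES))
full_score, _ = f_hs(seq)                        # run all 8 events
_, h_older = f_hs(seq[:, :5, :])                 # run the 5 older events once
resumed_score, _ = f_hs(seq[:, 5:, :], h_older)  # resume from their hidden state
print(np.allclose(full_score, resumed_score))    # True
```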
Model Wrappers

In order to facilitate the interface between models and TimeSHAP, TimeSHAP implements ModelWrappers. These wrappers, used in the PyTorch tutorial notebook, give greater flexibility in which models can be explained, as they allow for:

  • Batching logic: useful when using very large inputs or NSamples that cannot fit in GPU memory, so batching mechanisms are required;
  • Input format/type: useful when your model does not work with numpy arrays. This is the case for our PyTorch example.
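A minimal sketch of the batching idea, using a hypothetical BatchedModel class written in plain NumPy (it is not TimeSHAP's ModelWrapper API): the input is split into chunks of at most batch_size sequences, each chunk is scored separately, and the results are concatenated back into a single (#sequences, 1) array.

```python
import numpy as np
from typing import Callable

class BatchedModel:
    """Hypothetical wrapper: evaluates a scoring function on chunks of at most
    `batch_size` sequences and stacks the results along the first axis."""

    def __init__(self, f: Callable[[np.ndarray], np.ndarray], batch_size: int = 256):
        self.f = f
        self.batch_size = batch_size

    def __call__(self, x: np.ndarray) -> np.ndarray:
        # Input-type conversion (e.g. numpy -> framework tensors) could also go here.
        outputs = [self.f(x[i:i + self.batch_size])
                   for i in range(0, x.shape[0], self.batch_size)]
        return np.concatenate(outputs, axis=0)

calls = []
def toy_model(x):
    calls.append(x.shape[0])            # record the size of each batch
    return x.sum(axis=(1, 2))[:, None]  # (n, 1) scores

wrapped = BatchedModel(toy_model, batch_size=4)
out = wrapped(np.ones((10, 5, 2)))
print(out.shape, calls)  # (10, 1) [4, 4, 2]
```

Since the wrapper is itself a Callable with the same input and output shapes, it can be passed wherever the unwrapped model function was.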

TimeSHAP Explanation Methods

TimeSHAP offers several methods to use depending on the desired explanations. Local methods provide a detailed view of the model's decision for the specific sequence being explained. Global methods aggregate the local explanations of a given dataset to present a global view of the model.

Local Explanations

Pruning

local_pruning() performs the pruning algorithm on a given sequence with a user-defined tolerance and returns the pruning index along with the information needed for plotting.

plot_temp_coalition_pruning() plots the pruning algorithm information calculated by local_pruning().
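A simplified sketch of the pruning idea, under stated assumptions: the sequence is split into two players, the i most recent events and all older events, and the exact two-player Shapley value of the older block is computed with absent events replaced by a baseline event. Starting from i = 1, the first i for which that value falls below the tolerance gives the number of recent events to keep. The helper names are hypothetical; local_pruning() estimates these values with KernelSHAP and may differ in details such as how the pruning index is reported.

```python
import numpy as np

def older_events_shapley(f, x, baseline, i):
    """Exact Shapley value of the 'older events' player when x (shape (1, T, F)) is
    split into the older events x[:, :-i] and the i most recent events x[:, -i:].
    An absent player has its events replaced by the baseline event."""
    recent_only = x.copy()
    recent_only[:, :-i, :] = baseline  # older events removed
    older_only = x.copy()
    older_only[:, -i:, :] = baseline   # recent events removed
    neither = np.broadcast_to(baseline, x.shape).copy()
    v = lambda z: f(z)[0, 0]
    # With two players, the Shapley value is the mean of the two marginal contributions.
    return 0.5 * ((v(x) - v(recent_only)) + (v(older_only) - v(neither)))

def prune(f, x, baseline, tol):
    """Number of recent events to keep: the smallest i whose older-events Shapley
    value is below tol in absolute value (all T events if none qualifies)."""
    T = x.shape[1]
    for i in range(1, T):
        if abs(older_events_shapley(f, x, baseline, i)) < tol:
            return i
    return T

# Toy model: only feature 0 of the last 3 events affects the score.
model = lambda x: x[:, -3:, 0].sum(axis=1, keepdims=True)
x = np.ones((1, 10, 2))
print(prune(model, x, baseline=np.zeros(2), tol=0.025))  # 3
```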

Event level explanations

local_event() calculates event-level explanations of a given sequence with user-given parameters and returns the respective event-level explanations.

plot_event_heatmap() plots the event-level explanations calculated by local_event().
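To make event-level attributions concrete, here is a brute-force sketch that computes exact Shapley values with each event as a player, replacing events outside a coalition by a baseline event. This is a swapped-in technique for illustration only: TimeSHAP builds on KernelSHAP, which estimates these values by sampling (note the nsamples key in the tutorials' event_dict), and the function name exact_event_shapley is hypothetical. Exact enumeration needs an exponential number of model calls, so it is only practical for a handful of events.

```python
import itertools
import math
import numpy as np

def exact_event_shapley(f, x, baseline):
    """Exact Shapley values with each event (time step) of x, shape (1, T, F), as a
    player; events outside a coalition are replaced by the baseline event."""
    T = x.shape[1]

    def value(coalition):
        z = np.broadcast_to(baseline, x.shape).copy()
        idx = list(coalition)
        z[:, idx, :] = x[:, idx, :]  # keep only the coalition's events
        return f(z)[0, 0]

    phi = np.zeros(T)
    for t in range(T):
        others = [s for s in range(T) if s != t]
        for k in range(T):  # size of the coalition S not containing t
            weight = math.factorial(k) * math.factorial(T - k - 1) / math.factorial(T)
            for S in itertools.combinations(others, k):
                phi[t] += weight * (value(S + (t,)) - value(S))
    return phi

weights = np.array([1.0, 2.0, 3.0, 4.0])
model = lambda x: (x[:, :, 0] * weights).sum(axis=1, keepdims=True)  # additive toy model
phi = exact_event_shapley(model, np.ones((1, 4, 2)), baseline=np.zeros(2))
print(np.round(phi, 6))  # [1. 2. 3. 4.]
```

For an additive model each event's Shapley value equals its own contribution, and the values sum to f(x) minus f(baseline), which makes this a convenient sanity check.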

Feature level explanations

local_feat() calculates feature-level explanations of a given sequence with user-given parameters and returns the respective feature-level explanations.

plot_feat_barplot() plots the feature-level explanations calculated by local_feat().
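At the feature level, a player is a whole feature rather than a single event. A minimal sketch of the corresponding perturbation, assuming absent features are replaced at every time step by the baseline event's value for that feature (mask_features is a hypothetical helper, not part of TimeSHAP); using it in a Shapley estimator in place of event masking would yield feature-level attributions.

```python
import numpy as np

def mask_features(x, baseline, coalition):
    """Features in `coalition` keep their original values at every time step;
    every other feature is replaced, across the whole sequence, by the baseline."""
    z = np.broadcast_to(baseline, x.shape).copy()
    keep = list(coalition)
    z[:, :, keep] = x[:, :, keep]
    return z

x = np.arange(12, dtype=float).reshape(1, 4, 3)  # 1 sequence, 4 events, 3 features
z = mask_features(x, baseline=np.array([-1.0, -2.0, -3.0]), coalition=[1])
print(z[0])  # column 1 keeps 1, 4, 7, 10; columns 0 and 2 become -1 and -3
```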

Cell level explanations

local_cell_level() calculates cell-level explanations of a given sequence using the respective event- and feature-level explanations and user-given parameters, returning the respective cell-level explanations.

plot_cell_level() plots the cell-level explanations calculated by local_cell_level().
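The cell_dict used in the tutorial notebooks has top_x_events and top_x_feats keys. One plausible reading, sketched below with a hypothetical top_cells helper, is that the cells to explain are the intersections of the most important events and features, ranked by absolute event- and feature-level attribution. This is an assumption and has not been checked against local_cell_level().

```python
import numpy as np

def top_cells(event_phi, feat_phi, top_x_events, top_x_feats):
    """Hypothetical helper: (event, feature) cells at the intersection of the events
    and features with the largest absolute attributions."""
    top_events = np.argsort(-np.abs(event_phi))[:top_x_events]
    top_feats = np.argsort(-np.abs(feat_phi))[:top_x_feats]
    return [(int(e), int(ft)) for e in top_events for ft in top_feats]

cells = top_cells(np.array([0.01, -0.30, 0.20, 0.05]), np.array([0.4, -0.1, 0.25]),
                  top_x_events=2, top_x_feats=2)
print(cells)  # [(1, 0), (1, 2), (2, 0), (2, 2)]
```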

Local Report

local_report() calculates TimeSHAP local explanations for a given sequence and plots them.

Global Explanations

Global pruning statistics

prune_all() performs the pruning algorithm on multiple given sequences.

pruning_statistics() calculates the pruning statistics for several user-given pruning tolerances using the pruning data calculated by prune_all(), returning a pandas.DataFrame with the statistics.
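As an illustration only, pruning statistics of this kind can be sketched as a per-tolerance summary of sequence lengths before and after pruning. The helper pruning_stats, its inputs, and its output columns are hypothetical and do not reflect the columns pruning_statistics() actually returns.

```python
import numpy as np

def pruning_stats(seq_lengths, pruned_lengths_by_tol):
    """Hypothetical aggregation: for each tolerance, the mean original length and the
    mean pruned length over all explained sequences."""
    rows = []
    for tol, pruned in sorted(pruned_lengths_by_tol.items()):
        rows.append({"tolerance": tol,
                     "mean_original": float(np.mean(seq_lengths)),
                     "mean_pruned": float(np.mean(pruned))})
    return rows

stats = pruning_stats([100, 80, 120], {0.05: [10, 8, 12], 0.01: [30, 20, 40]})
for row in stats:
    print(row)
```

The list of rows can be turned into a table with pandas.DataFrame(stats); a smaller tolerance keeps more events, as the example shows.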

Global event level explanations

event_explain_all() calculates TimeSHAP event level explanations for multiple instances given user defined parameters.

plot_global_event() plots the global event-level explanations calculated by event_explain_all().
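A sketch of one way to aggregate local event-level explanations across sequences of different lengths: align them by distance from the most recent event (-1, -2, ...) and average each position. The helper is hypothetical, and plot_global_event() may aggregate differently (for example, plotting every point rather than a mean).

```python
import numpy as np

def mean_event_attribution_by_lag(event_phis, max_lag):
    """Align per-sequence event attributions, which may differ in length, by their
    distance from the most recent event (-1, -2, ...) and average each position."""
    means = {}
    for lag in range(1, max_lag + 1):
        vals = [phi[-lag] for phi in event_phis if len(phi) >= lag]
        if vals:  # skip lags that no sequence reaches
            means[-lag] = float(np.mean(vals))
    return means

agg = mean_event_attribution_by_lag(
    [np.array([0.1, 0.3, 0.6]), np.array([0.2, 0.4])], max_lag=4)
print(agg)  # keys -1, -2, -3; lag 4 is skipped because no sequence is that long
```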

Global feature level explanations

feat_explain_all() calculates TimeSHAP feature level explanations for multiple instances given user defined parameters.

plot_global_feat() plots the global feature-level explanations calculated by feat_explain_all().
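A common way to summarize local feature-level explanations globally is the mean absolute attribution per feature. A hypothetical sketch, not necessarily how feat_explain_all() aggregates:

```python
import numpy as np

def global_feature_importance(local_feat_phis, feature_names):
    """Rank features by mean absolute attribution across explained instances.
    local_feat_phis has shape (n_instances, n_features)."""
    importance = np.abs(np.asarray(local_feat_phis)).mean(axis=0)
    order = np.argsort(-importance)
    return [(feature_names[i], float(importance[i])) for i in order]

ranking = global_feature_importance([[0.2, -0.1, 0.0], [-0.4, 0.1, 0.05]], ["a", "b", "c"])
print([name for name, _ in ranking])  # ['a', 'b', 'c']
```

Taking absolute values before averaging prevents positive and negative contributions of the same feature from cancelling out across instances.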

Global report

global_report() calculates TimeSHAP explanations for multiple instances, aggregating the explanations on two plots and returning them.

Tutorial

To see TimeSHAP's interfaces and methods in action, consult AReM.ipynb. In this tutorial we take an open-source dataset, process it, train a PyTorch recurrent model on it, and use TimeSHAP to explain it, showcasing all the methods described above.

Additionally, AReM_TF.ipynb trains a TensorFlow model on the same dataset.

Repository Structure

Citing TimeSHAP

@inproceedings{bento2021timeshap,
    author = {Bento, Jo\~{a}o and Saleiro, Pedro and Cruz, Andr\'{e} F. and Figueiredo, M\'{a}rio A.T. and Bizarro, Pedro},
    title = {TimeSHAP: Explaining Recurrent Models through Sequence Perturbations},
    year = {2021},
    isbn = {9781450383325},
    publisher = {Association for Computing Machinery},
    address = {New York, NY, USA},
    url = {https://doi.org/10.1145/3447548.3467166},
    doi = {10.1145/3447548.3467166},
    booktitle = {Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery & Data Mining},
    pages = {2565–2573},
    numpages = {9},
    keywords = {SHAP, Shapley values, TimeSHAP, XAI, RNN, explainability},
    location = {Virtual Event, Singapore},
    series = {KDD '21}
}
Comments
  • Error in running the example notebook (AReM_TF)

    Nice work! I have been trying to run one of the tutorial notebooks in the repository (AReM_TF), but I ran into an error. The notebook cell that produces the error is:

    from timeshap.explainer import local_report, local_pruning
    
    pruning_dict = {'tol': 0.025}
    event_dict = {'rs': 42, 'nsamples': 320}
    feature_dict = {'rs': 42, 'nsamples': 320, 'feature_names': model_features, 'plot_features': plot_feats}
    cell_dict = {'rs': 42, 'nsamples': 320, 'top_x_feats': 2, 'top_x_events': 2}
    local_report(f, pos_x_data, pruning_dict, event_dict, feature_dict,cell_dict=cell_dict, entity_uuid=positive_sequence_id, entity_col='all_id', baseline=average_event)
    

    and the error it produces is shown in two screenshots (not preserved).

    It would be great if you could help me with this error.

    opened by aminnayebi 7
  • Unable to install timeshap package

    Hello @feedzaiadmin, @saleiro,

    I am unable to run the command "pip install timeshap" on my Ubuntu system. It throws the following error:

    ERROR: Could not find a version that satisfies the requirement timeshap (from versions: none)
    ERROR: No matching distribution found for timeshap

    This seems to be a compatibility issue. I have tried Python 3.6 and 3.8. Is there a specific version that supports it? I look forward to your response.

    Thanks

    opened by vishants98 5
  • How to adapt the transformation function to account for variable sequence length?

    I am trying to use TimeSHAP for my use case. As I understand it, in the AReM example the df_to_numpy function transforms the data so that a prediction is made for the last value of each sequence (screenshot not preserved).

    In the AReM tutorial data, the predictions are based on the whole sequence: all rows (row IDs 1-10) are used for sequence ID 1 (light blue) and the prediction is made for timestamp 10 (dark blue; row ID 10). Next, the light orange rows (row IDs 11-20) are used to predict the label marked in dark orange (row ID 20).

    In my use case, the model predicts on a rolling-window basis and I need predictions for every row, not only for each sequence (screenshot not preserved). Say my rolling window is 6: row IDs 1-6 (light green) are used to predict row ID 7 (dark green), then row IDs 2-7 (light grey) are used to predict row ID 8 (dark grey), and so on. When a new sequence starts, we repeat the process: we take row IDs 11-16 and predict row ID 17, etc. For my use case, it's important to evaluate the predictions for every row ID, not only for the whole sequence.

    The problem I am facing is that when I run the function get_avg_score_with_avg_event on data structured as above, I get an error (screenshot not preserved).

    My data is transformed from 2D into 3D format by a function shown in a screenshot (not preserved).
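The author's transformation function is not preserved. A hypothetical reconstruction from the description above, sketched in NumPy: each window of `window` consecutive rows stays within one sequence, and the row right after it is the prediction target, which yields the (#sequences, #sequence length, #features) layout TimeSHAP expects. The helper name rolling_windows and its signature are made up for illustration.

```python
import numpy as np

def rolling_windows(values, seq_ids, window):
    """Hypothetical helper: build (n_windows, window, n_features) inputs where each
    window holds `window` consecutive rows of one sequence. The row right after a
    window is its prediction target, so windows never cross sequence boundaries."""
    windows, targets = [], []
    for sid in np.unique(seq_ids):
        rows = values[seq_ids == sid]
        for start in range(len(rows) - window):
            windows.append(rows[start:start + window])
            targets.append(start + window)  # 0-based position of the predicted row in its sequence
    return np.stack(windows), targets

values = np.arange(40, dtype=float).reshape(20, 2)  # 20 rows, 2 features
seq_ids = np.repeat([1, 2], 10)                     # two sequences of 10 rows each
windows, targets = rolling_windows(values, seq_ids, window=6)
print(windows.shape)  # (8, 6, 2): 4 windows per sequence
```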

    My question is whether it's possible to make TimeSHAP work with data transformed as described in my use case. When I use the transformation defined in your df_to_numpy function, I don't get an error, but it is not adapted to my use case.

    opened by grzechowiak 4
  • Issue in reproducing TimeSHAP Tutorial - TensorFlow - AReM dataset

    Foremost, thanks for the library, really great job!

    I am trying to reproduce your TimeSHAP Tutorial for TF and I am having an issue in the Global Explanations section when running the global_report() function (screenshot not preserved).

    The error concerns encoding the \u2264 character, the <= sign. I tried to solve it myself by adding encoding="utf-8" to the open call on line 326 of pruning.py (with open(file_path, 'a', newline='') as file:), but it didn't solve the problem. Any advice is very welcome!

    Also, for completeness, I had a problem loading the data (error screenshots not preserved). I was only able to solve it by deleting 2 datasets, cycling/dataset9.csv and cycling/dataset14.csv, after which the rest of the code worked.

    opened by grzechowiak 4
  • Timeshap for regression

    I am working on a time series forecasting problem using LSTM. Can I use timeshap for such a regression problem? Do you by chance have a demo for regression?

    Thanks

    opened by mgorjis 3
  • Plot Coalition Pruning is not working

    Hi, I am doing a project as part of my Master's thesis. I was testing your introductory notebook with the TensorFlow implementation, but it fails because of an error raised by one of your plot functions. Curiously, your other notebook, the one with the torch model, works fine. I tracked the issue down to an inner method called solve_negatives_method in src/timeshap/plot/pruning.py, which raises an InvalidIndexError caused by this specific line:

        df.at[corresponding_row.index, 'Shapley Value'] = corresponding_row['Shapley Value'].values[0] + row['Shapley Value']
    

    As far as I understand, we are passing an Index, usually containing only one value, to the pandas.DataFrame.at method, which requires a single row label and a column label. One solution that works is:

        df.at[corresponding_row.index[0], 'Shapley Value'] = corresponding_row['Shapley Value'].values[0] + row['Shapley Value']
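A minimal, self-contained check of the proposed fix on a hypothetical two-row frame (the column names come from the snippet above; the coalition labels and values are made up): .at receives the single row label corresponding_row.index[0] instead of an Index object.

```python
import pandas as pd

df = pd.DataFrame({"t (event index)": [-1, -1],
                   "Coalition": ["Pruned Events", "Sequence"],  # hypothetical labels
                   "Shapley Value": [0.3, -0.1]})
idx, row = 1, df.loc[1]  # the row holding the negative value
corresponding_row = df[(df["t (event index)"] == row["t (event index)"])
                       & (df["Coalition"] != row["Coalition"])]
# .at expects scalar labels, so take the first (only) label of the Index
df.at[corresponding_row.index[0], "Shapley Value"] = (
    corresponding_row["Shapley Value"].values[0] + row["Shapley Value"])
df.at[idx, "Shapley Value"] = 0
print(df["Shapley Value"].tolist())  # approximately [0.2, 0.0]
```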
    

    Thank you very much for the work done; I will be happy to discuss the further results of my thesis with you. :) Hoping for the best and looking forward to hearing from you, Eric

    opened by Erhtric 2
  • TimeSHAP for text?

    I'm working with a 1-layer GRU for text classification that takes BERT embeddings at the input. Each input sequence is of the shape (sequence length, bert-embedding-dimension). I'm looking for word level attribution scores for each sequence's prediction. Currently with the captum integrated gradients and occlusion explainers, I get attribution scores that are almost always the last few words of each sequence. This seems like it's stemming from the directional processing of GRU - any thoughts?

    Do you think TimeSHAP would be applicable to my use case? I suppose I could consider each word as an event and each embedding dimension as a feature, and then use the event-level local explanations from the library. However, in my case the events (i.e., words) at the beginning of the sequence could be more important than those at the end (i.e., the most recent ones). This violates the assumptions you use for your approximation (i.e., pruning), so perhaps it's not applicable to text?

    opened by itsmemala 2
  • Intuition around pruning/baseline selection

    While calculating a global report, I keep running into the error "Score difference between baseline and instance is too low < 0.1...Consider choosing another baseline." My baselines have been the average_event and average_sequence. I also notice that the values in the error occasionally change with different pruning tolerances, but not in a consistent way. Do you have advice for dealing with this? Thanks.

    opened by xydisla 2
  • CNN model

    I currently have a CNN model, and previously had to do some strange hacking to get time-series importance values. Your package sheds better light on this issue. For multivariate forecasts, CNNs still fare well, sometimes better than RNNs, and from what I can see there is no reason why CNNs wouldn't work with your software. Let me know if I have that wrong.

    opened by firmai 2
  • How to speed up TIMESHAP computation

    Hi all!

    The package itself is really interesting and intuitive to use, but I want to speed up the TimeSHAP computation. Can I use a GPU to calculate Shapley values? I used the device parameter in TorchModelWrapper, but the GPU is used too inefficiently to accelerate the TimeSHAP computation. Any suggestion would be appreciated.

    opened by Changshu135 2
  • Error when executing local_report on TF example: InvalidIndexError: Int64Index([1], dtype='int64')

    Hi all! When executing your TF example notebook without any changes, it fails at this line: local_report(f, pos_x_data, pruning_dict, event_dict, feature_dict, cell_dict=cell_dict, entity_uuid=positive_sequence_id, entity_col='all_id', baseline=average_event)

    with the following error: InvalidIndexError: Int64Index([1], dtype='int64')

    Stack trace:

    TypeError                                 Traceback (most recent call last)
    File ~/opt/anaconda3/envs/timeshap/lib/python3.10/site-packages/pandas/core/indexes/base.py:3621, in Index.get_loc(self, key, method, tolerance)
       3620 try:
    -> 3621     return self._engine.get_loc(casted_key)
       3622 except KeyError as err:
    
    File pandas/_libs/index.pyx:136, in pandas._libs.index.IndexEngine.get_loc()
    
    File pandas/_libs/index.pyx:142, in pandas._libs.index.IndexEngine.get_loc()
    
    TypeError: 'Int64Index([1], dtype='int64')' is an invalid key
    
    During handling of the above exception, another exception occurred:
    
    InvalidIndexError                         Traceback (most recent call last)
    Input In [27], in <cell line: 7>()
          5 feature_dict = {'rs': 42, 'nsamples': 32000, 'feature_names': model_features, 'plot_features': plot_feats}
          6 cell_dict = {'rs': 42, 'nsamples': 32000, 'top_x_feats': 2, 'top_x_events': 2}
    ----> 7 local_report(f, pos_x_data, pruning_dict, event_dict, feature_dict, cell_dict=cell_dict, entity_uuid=positive_sequence_id, entity_col='all_id', baseline=average_event)
    
    File ~/temp/timeshap/src/timeshap/explainer/local_methods.py:139, in local_report(f, data, pruning_dict, event_dict, feature_dict, cell_dict, entity_uuid, entity_col, time_col, model_features, baseline, verbose)
        137 pruning_idx = data.shape[1] + coal_prun_idx
        138 plot_lim = max(abs(coal_prun_idx)+10, 40)
    --> 139 pruning_plot = plot_temp_coalition_pruning(coal_plot_data, coal_prun_idx, plot_lim)
        141 event_data = local_event(f, data, event_dict, entity_uuid, entity_col, baseline, pruning_idx)
        142 event_plot = plot_event_heatmap(event_data)
    
    File ~/temp/timeshap/src/timeshap/plot/pruning.py:53, in plot_temp_coalition_pruning(df, pruned_idx, plot_limit, solve_negatives)
         51 df = df[df['t (event index)'] >= -plot_limit]
         52 if solve_negatives:
    ---> 53     df = solve_negatives_method(df)
         55 base = (alt.Chart(df).encode(
         56     x=alt.X("t (event index)", axis=alt.Axis(title='t (event index)', labelFontSize=15,
         57                           titleFontSize=15)),
       (...)
         70 )
         71 )
         73 area_chart = base.mark_area(opacity=0.5)
    
    File ~/temp/timeshap/src/timeshap/plot/pruning.py:47, in plot_temp_coalition_pruning.<locals>.solve_negatives_method(df)
         45 for idx, row in negative_values.iterrows():
         46     corresponding_row = df[np.logical_and(df['t (event index)'] == row['t (event index)'], ~(df['Coalition'] == row['Coalition']))]
    ---> 47     df.at[corresponding_row.index, 'Shapley Value'] = corresponding_row['Shapley Value'].values[0] + row['Shapley Value']
         48     df.at[idx, 'Shapley Value'] = 0
         49 return df
    
    File ~/opt/anaconda3/envs/timeshap/lib/python3.10/site-packages/pandas/core/indexing.py:2281, in _AtIndexer.__setitem__(self, key, value)
       2278     self.obj.loc[key] = value
       2279     return
    -> 2281 return super().__setitem__(key, value)
    
    File ~/opt/anaconda3/envs/timeshap/lib/python3.10/site-packages/pandas/core/indexing.py:2236, in _ScalarAccessIndexer.__setitem__(self, key, value)
       2233 if len(key) != self.ndim:
       2234     raise ValueError("Not enough indexers for scalar access (setting)!")
    -> 2236 self.obj._set_value(*key, value=value, takeable=self._takeable)
    
    File ~/opt/anaconda3/envs/timeshap/lib/python3.10/site-packages/pandas/core/frame.py:3869, in DataFrame._set_value(self, index, col, value, takeable)
       3867 else:
       3868     series = self._get_item_cache(col)
    -> 3869     loc = self.index.get_loc(index)
       3871 # setitem_inplace will do validation that may raise TypeError
       3872 #  or ValueError
       3873 series._mgr.setitem_inplace(loc, value)
    
    File ~/opt/anaconda3/envs/timeshap/lib/python3.10/site-packages/pandas/core/indexes/base.py:3628, in Index.get_loc(self, key, method, tolerance)
       3623         raise KeyError(key) from err
       3624     except TypeError:
       3625         # If we have a listlike key, _check_indexing_error will raise
       3626         #  InvalidIndexError. Otherwise we fall through and re-raise
       3627         #  the TypeError.
    -> 3628         self._check_indexing_error(key)
       3629         raise
       3631 # GH#42269
    
    File ~/opt/anaconda3/envs/timeshap/lib/python3.10/site-packages/pandas/core/indexes/base.py:5637, in Index._check_indexing_error(self, key)
       5633 def _check_indexing_error(self, key):
       5634     if not is_scalar(key):
       5635         # if key is not a scalar, directly raise an error (the code below
       5636         # would convert to numpy arrays and raise later any way) - GH29926
    -> 5637         raise InvalidIndexError(key)
    
    InvalidIndexError: Int64Index([1], dtype='int64')
    
    opened by JulianKlug 2
Owner: Feedzai