SimplEx - Explaining Latent Representations with a Corpus of Examples

Overview

Code Author: Jonathan Crabbé ([email protected])

This repository contains the implementation of SimplEx, a method to explain the latent representations of black-box models with the help of a corpus of examples. For more details, please read our NeurIPS 2021 paper: 'Explaining Latent Representations with a Corpus of Examples'.

Installation

  1. Clone the repository
  2. Create a new virtual environment with Python 3.8
  3. Run the following command from the repository folder:
    pip install -r requirements.txt #install requirements

Once the packages are installed, SimplEx can be used directly.

Toy example

Below is a toy demonstration in which we compute a corpus decomposition of the latent representations of test examples. All the relevant code can be found in the file simplex.

from explainers.simplex import Simplex
from models.base import BlackBox

# Get the model and the examples
# (get_corpus and get_test stand for your own data-loading functions)
model = BlackBox()  # Model should have the BlackBox interface
corpus_inputs = get_corpus()  # A tensor of corpus inputs
test_inputs = get_test()  # A set of inputs to explain

# Compute the corpus and test latent representations
corpus_latents = model.latent_representation(corpus_inputs) 
test_latents = model.latent_representation(test_inputs)

# Initialize SimplEx and fit it on the test examples
simplex = Simplex(corpus_examples=corpus_inputs, 
                  corpus_latent_reps=corpus_latents)
simplex.fit(test_examples=test_inputs, 
            test_latent_reps=test_latents,
            reg_factor=0)

# Get the weights of each corpus decomposition
weights = simplex.weights

We get a tensor weights that can be interpreted as follows: weights[i,c] = weight of corpus example c in the decomposition of example i.
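
As a minimal sketch of how to read this tensor, the snippet below uses a small hand-made weight matrix (plain Python lists standing in for the tensor returned by simplex.weights) and a hypothetical helper, top_corpus_examples, that is not part of the repository. For a test example i, it lists the corpus examples with the largest weights. It assumes that each row holds non-negative weights summing to 1, as in the corpus decomposition described in the paper.

```python
from typing import List, Tuple


def top_corpus_examples(
    weights: List[List[float]], i: int, k: int = 3
) -> List[Tuple[int, float]]:
    """Return the k (corpus index, weight) pairs with the largest weight for test example i."""
    row = weights[i]
    ranked = sorted(enumerate(row), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Toy stand-in for simplex.weights: 2 test examples, 4 corpus examples
toy_weights = [
    [0.70, 0.05, 0.20, 0.05],
    [0.10, 0.60, 0.00, 0.30],
]

for i, row in enumerate(toy_weights):
    # Assumption: each row is a convex combination, so its weights sum to 1
    assert abs(sum(row) - 1.0) < 1e-6
    print(i, top_corpus_examples(toy_weights, i, k=2))
```

If weights is a PyTorch tensor, torch.topk(weights[i], k) should give the same values and indices without leaving the tensor library.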

We can get the importance of each corpus feature for the decomposition of a given example i in the following way:

# Compute the Integrated Jacobian for a particular example
i = 42
input_baseline = get_baseline() # Baseline tensor of the same shape as corpus_inputs
simplex.jacobian_projections(test_id=i, model=model,
                             input_baseline=input_baseline)

result = simplex.decompose(i)

We get a list result where each element of the list corresponds to a corpus example. This list is sorted by decreasing order of importance in the corpus decomposition. Each element of the list is a tuple structured as follows:

w_c, x_c, proj_jacobian_c = result[c]

Here, w_c corresponds to the weight weights[i,c], x_c corresponds to corpus_inputs[c], and proj_jacobian_c is a tensor such that proj_jacobian_c[k] is the Projected Jacobian of feature k from corpus example c.
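
The following sketch shows one way such a list could be consumed. It is an illustration only: toy_result is a hand-made stand-in for the output of simplex.decompose(i), the features are flattened into plain Python lists (real inputs are tensors and may be multi-dimensional), and summarize_decomposition is a hypothetical helper that does not exist in the repository. For each of the most important corpus examples, it reports the weight and the index of the feature with the largest absolute Projected Jacobian.

```python
from typing import List, Sequence, Tuple

CorpusTerm = Tuple[float, Sequence[float], Sequence[float]]


def summarize_decomposition(
    result: Sequence[CorpusTerm], top_k: int = 2
) -> List[Tuple[float, int]]:
    """For the first top_k corpus examples, return (weight, index of the most important feature)."""
    summary = []
    for w_c, x_c, proj_jacobian_c in result[:top_k]:
        # Feature whose Projected Jacobian has the largest absolute value
        best_k = max(range(len(proj_jacobian_c)), key=lambda k: abs(proj_jacobian_c[k]))
        summary.append((w_c, best_k))
    return summary


# Toy stand-in for simplex.decompose(i), already sorted by decreasing weight
toy_result = [
    (0.6, [1.0, 0.0, 2.0], [0.1, -0.5, 0.3]),
    (0.3, [0.5, 1.5, 0.0], [0.4, 0.0, -0.1]),
    (0.1, [0.0, 0.0, 1.0], [0.0, 0.0, 0.2]),
]

print(summarize_decomposition(toy_result))
```

Because the list is sorted by decreasing weight, slicing the first few elements is enough to focus on the corpus examples that dominate the decomposition.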

Reproducing the paper results

Reproducing MNIST Approximation Quality Experiment

  1. Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 9):
    python -m experiments.mnist -experiment "approximation_quality" -cv CV
  2. Run the following script, passing all the values of CV from the previous step:
    python -m experiments.results.mnist.quality.plot_results -cv_list CV1 CV2 CV3 ...
  3. The resulting plots and data are saved here.
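
Since the first step has to be repeated for every value of CV, it can be convenient to generate the commands programmatically. The sketch below is not part of the repository: experiment_commands and plot_command are hypothetical helpers that only build the argument lists for the commands shown above, and the range 0 to 9 is assumed to be inclusive. The same pattern should apply to the other experiments by changing the module and experiment names.

```python
import subprocess
import sys
from typing import Iterable, List


def experiment_commands(
    module: str, experiment: str, cv_values: Iterable[int]
) -> List[List[str]]:
    """Build one argument list per CV value, mirroring the command of step 1."""
    return [
        [sys.executable, "-m", module, "-experiment", experiment, "-cv", str(cv)]
        for cv in cv_values
    ]


def plot_command(module: str, cv_values: Iterable[int]) -> List[str]:
    """Build the plotting command of step 2 for the given CV values."""
    return [sys.executable, "-m", module, "-cv_list", *[str(cv) for cv in cv_values]]


# Assumption: "between 0 and 9" means the ten values 0, 1, ..., 9
commands = experiment_commands("experiments.mnist", "approximation_quality", range(10))
for argv in commands:
    print(" ".join(argv))
    # To actually launch the run from the repository folder, uncomment:
    # subprocess.run(argv, check=True)

print(" ".join(plot_command("experiments.results.mnist.quality.plot_results", range(10))))
```

Building argument lists rather than shell strings avoids quoting issues and keeps the dry run (printing only) separate from the actual execution.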

Reproducing Prostate Cancer Approximation Quality Experiment

This experiment requires access to the private datasets CUTRACT and SEER described in the paper.

  1. Copy the files cutract_internal_all.csv and seer_external_imputed_new.csv into the folder data/Prostate Cancer
  2. Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 9):
    python -m experiments.prostate_cancer -experiment "approximation_quality" -cv CV
  3. Run the following script, passing all the values of CV from the previous step:
    python -m experiments.results.prostate.quality.plot_results -cv_list CV1 CV2 CV3 ...
  4. The resulting plots are saved here.

Reproducing Prostate Cancer Outlier Experiment

This experiment requires access to the private datasets CUTRACT and SEER described in the paper.

  1. Make sure that the files cutract_internal_all.csv and seer_external_imputed_new.csv are in the folder data/Prostate Cancer
  2. Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 9):
    python -m experiments.prostate_cancer -experiment "outlier_detection" -cv CV
  3. Run the following script, passing all the values of CV from the previous step:
    python -m experiments.results.prostate.outlier.plot_results -cv_list CV1 CV2 CV3 ...
  4. The resulting plots are saved here.

Reproducing MNIST Jacobian Projection Significance Experiment

  1. Run the following script:
    python -m experiments.mnist -experiment "jacobian_corruption"
  2. The resulting plots and data are saved here.

Reproducing MNIST Outlier Detection Experiment

  1. Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 9):
    python -m experiments.mnist -experiment "outlier_detection" -cv CV
  2. Run the following script, passing all the values of CV from the previous step:
    python -m experiments.results.mnist.outlier.plot_results -cv_list CV1 CV2 CV3 ...
  3. The resulting plots and data are saved here.

Reproducing MNIST Influence Function Experiment

  1. Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 4):
    python -m experiments.mnist -experiment "influence" -cv CV
  2. Run the following script, passing all the values of CV from the previous step:
    python -m experiments.results.mnist.influence.plot_results -cv_list CV1 CV2 CV3 ...
  3. The resulting plots and data are saved here.

Note: the package Pytorch Influence Functions can cause problems. If it does, please change calc_influence_function.py in the following way (the leading numbers are line numbers in that file):

343: influences.append(tmp_influence) ==> influences.append(tmp_influence.cpu())
438: influences_meta['test_sample_index_list'] = sample_list ==> #influences_meta['test_sample_index_list'] = sample_list

Reproducing AR Approximation Quality Experiment

  1. Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 4):
    python -m experiments.time_series -experiment "approximation_quality" -cv CV
  2. Run the following script, passing all the values of CV from the previous step:
    python -m experiments.results.ar.quality.plot_results -cv_list CV1 CV2 CV3 ...
  3. The resulting plots and data are saved here.

Reproducing AR Outlier Detection Experiment

  1. Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 4):
    python -m experiments.time_series -experiment "outlier_detection" -cv CV
  2. Run the following script, passing all the values of CV from the previous step:
    python -m experiments.results.ar.outlier.plot_results -cv_list CV1 CV2 CV3 ...
  3. The resulting plots and data are saved here.

Citing

If you use this code, please cite the associated paper:

Put citation here when ready