Implementation of the paper "Shapley Explanation Networks"

Overview

Shapley Explanation Networks

Implementation of the paper "Shapley Explanation Networks" at ICLR 2021. Note that this repo heavily uses the experimental feature of named tensors in PyTorch. As it was really confusing to implement the ideas for the authors, we find it tremendously easier to use this feature.

Dependencies

For running only ShapNets, one would mostly only need PyTorch, NumPy, and SciPy.

Usage

For a Shapley Module:

import torch
import torch.nn as nn
from ShapNet.utils import ModuleDimensions
from ShapNet import ShapleyModule

b_size = 3
features = 4
out = 1
dims = ModuleDimensions(
    features=features,
    in_channel=1,
    out_channel=out
)

sm = ShapleyModule(
    inner_function=nn.Linear(features, out),
    dimensions=dims
)
sm(torch.randn(b_size, features), explain=True)

For a Shallow ShapNet

import torch
import torch.nn as nn
from ShapNet.utils import ModuleDimensions
from ShapNet import ShapleyModule, OverlappingShallowShapleyNetwork

batch_size = 32
class_num = 10
dim = 32

overlapping_modules = [
    ShapleyModule(
        inner_function=nn.Sequential(nn.Linear(2, class_num)),
        dimensions=ModuleDimensions(
            features=2, in_channel=1, out_channel=class_num
        ),
    ) for _ in range(dim * (dim - 1) // 2)
]
shallow_shapnet = OverlappingShallowShapleyNetwork(
    list_modules=overlapping_modules
)
inputs = torch.randn(batch_size, dim, ), )
shallow_shapnet(torch.randn(batch_size, dim, ), )
output, bias = shallow_shapnet(inputs, explain=True, )

For a Deep ShapNet

import torch
import torch.nn as nn
from ShapNet.utils import ModuleDimensions
from ShapNet import ShapleyModule, ShallowShapleyNetwork, DeepShapleyNetwork

dim = 32
dim_input_channels = 1
class_num = 10
inputs = torch.randn(32, dim, ), )


dims = ModuleDimensions(
    features=dim,
    in_channel=dim_input_channels,
    out_channel=class_num
)
deep_shapnet = DeepShapleyNetwork(
    list_shapnets=[
        ShallowShapleyNetwork(
            module_dict=nn.ModuleDict({
                "(0, 2)": ShapleyModule(
                    inner_function=nn.Linear(2, class_num),
                    dimensions=ModuleDimensions(
                        features=2, in_channel=1, out_channel=class_num
                    )
                )},
            ),
            dimensions=ModuleDimensions(dim, 1, class_num)
        ),
    ],
)
deep_shapnet(inputs)
outputs = deep_shapnet(inputs, explain=True, )

For a vision model:

import numpy as np
import torch
import torch.nn as nn

# =============================================================================
# Imports {\sc ShapNet}
# =============================================================================
from ShapNet import DeepConvShapNet, ShallowConvShapleyNetwork, ShapleyModule
from ShapNet.utils import ModuleDimensions, NAME_HEIGHT, NAME_WIDTH, \
    process_list_sizes

num_channels = 3
num_classes = 10
height = 32
width = 32
list_channels = [3, 16, 10]
pruning = [0.2, 0.]
kernel_sizes = process_list_sizes([2, (1, 3), ])
dilations = process_list_sizes([1, 2])
paddings = process_list_sizes([0, 0])
strides = process_list_sizes([1, 1])

args = {
    "list_shapnets": [
        ShallowConvShapleyNetwork(
            shapley_module=ShapleyModule(
                inner_function=nn.Sequential(
                    nn.Linear(
                        np.prod(kernel_sizes[i]) * list_channels[i],
                        list_channels[i + 1]),
                    nn.LeakyReLU()
                ),
                dimensions=ModuleDimensions(
                    features=int(np.prod(kernel_sizes[i])),
                    in_channel=list_channels[i],
                    out_channel=list_channels[i + 1])
            ),
            reference_values=None,
            kernel_size=kernel_sizes[i],
            dilation=dilations[i],
            padding=paddings[i],
            stride=strides[i]
        ) for i in range(len(list_channels) - 1)
    ],
    "reference_values": None,
    "residual": False,
    "named_output": False,
    "pruning": pruning
}

dcs = DeepConvShapNet(**args)

Citation

If this is useful, you could cite our work as

@inproceedings{
wang2021shapley,
title={Shapley Explanation Networks},
author={Rui Wang and Xiaoqian Wang and David I. Inouye},
booktitle={International Conference on Learning Representations},
year={2021},
url={https://openreview.net/forum?id=vsU0efpivw}
}
Owner
Prof. David I. Inouye's research lab at Purdue University.
AI-generated-characters for Learning and Wellbeing

AI-generated-characters for Learning and Wellbeing Click here for the full project page. This repository contains the source code for the paper AI-gen

MIT Media Lab 214 Jan 01, 2023
End-to-end image segmentation kit based on PaddlePaddle.

English | 简体中文 PaddleSeg PaddleSeg has released the new version including the following features: Our team won the 6.2k Jan 02, 2023

NeurIPS 2021 paper 'Representation Learning on Spatial Networks' code

Representation Learning on Spatial Networks This repository is the official implementation of Representation Learning on Spatial Networks. Training Ex

13 Dec 29, 2022
This project is based on RIFE and aims to make RIFE more practical for users by adding various features and design new models

CPM 项目描述 CPM(Chinese Pretrained Models)模型是北京智源人工智能研究院和清华大学发布的中文大规模预训练模型。官方发布了三种规模的模型,参数量分别为109M、334M、2.6B,用户需申请与通过审核,方可下载。 由于原项目需要考虑大模型的训练和使用,需要安装较为复杂

hzwer 190 Jan 08, 2023
A package related to building quasi-fibration symmetries

qf A package related to building quasi-fibration symmetries. If you'd like to learn more about how it works, see the brief explanation and References

Paolo Boldi 1 Dec 01, 2021
Codes for Causal Semantic Generative model (CSG), the model proposed in "Learning Causal Semantic Representation for Out-of-Distribution Prediction" (NeurIPS-21)

Learning Causal Semantic Representation for Out-of-Distribution Prediction This repository is the official implementation of "Learning Causal Semantic

Chang Liu 54 Dec 01, 2022
TensorFlow-based implementation of "Pyramid Scene Parsing Network".

PSPNet_tensorflow Important Code is fine for inference. However, the training code is just for reference and might be only used for fine-tuning. If yo

HsuanKung Yang 323 Dec 20, 2022
This repository contains the implementation of the following paper: Cross-Descriptor Visual Localization and Mapping

Cross-Descriptor Visual Localization and Mapping This repository contains the implementation of the following paper: "Cross-Descriptor Visual Localiza

Mihai Dusmanu 81 Oct 06, 2022
Semi-supervised Learning for Sentiment Analysis

Neural-Semi-supervised-Learning-for-Text-Classification-Under-Large-Scale-Pretraining Code, models and Datasets for《Neural Semi-supervised Learning fo

47 Jan 01, 2023
Scikit-learn compatible estimation of general graphical models

skggm : Gaussian graphical models using the scikit-learn API In the last decade, learning networks that encode conditional independence relationships

213 Jan 02, 2023
details on efforts to dump the Watermelon Games Paprium cart

Reminder, if you like these repos, fork them so they don't disappear https://github.com/ArcadeHustle/WatermelonPapriumDump/fork Big thanks to Fonzie f

Hustle Arcade 29 Dec 11, 2022
Official PyTorch implementation of PS-KD

Self-Knowledge Distillation with Progressive Refinement of Targets (PS-KD) Accepted at ICCV 2021, oral presentation Official PyTorch implementation of

61 Dec 28, 2022
A task-agnostic vision-language architecture as a step towards General Purpose Vision

Towards General Purpose Vision Systems By Tanmay Gupta, Amita Kamath, Aniruddha Kembhavi, and Derek Hoiem Overview Welcome to the official code base f

AI2 79 Dec 23, 2022
Orange Chicken: Data-driven Model Generalizability in Crosslinguistic Low-resource Morphological Segmentation

Orange Chicken: Data-driven Model Generalizability in Crosslinguistic Low-resource Morphological Segmentation This repository contains code and data f

Zoey Liu 0 Jan 07, 2022
AI Virtual Calculator: This is a simple virtual calculator based on Artificial intelligence.

AI Virtual Calculator: This is a simple virtual calculator that works with gestures using OpenCV. We will use our hand in the air to click on the calc

Md. Rakibul Islam 1 Jan 13, 2022
Implementing SYNTHESIZER: Rethinking Self-Attention in Transformer Models using Pytorch

Implementing SYNTHESIZER: Rethinking Self-Attention in Transformer Models using Pytorch Reference Paper URL Author: Yi Tay, Dara Bahri, Donald Metzler

Myeongjun Kim 66 Nov 30, 2022
Read number plates with https://platerecognizer.com/

HASS-plate-recognizer Read vehicle license plates with https://platerecognizer.com/ which offers free processing of 2500 images per month. You will ne

Robin 69 Dec 30, 2022
Library for time-series-forecasting-as-a-service.

TIMEX TIMEX (referred in code as timexseries) is a framework for time-series-forecasting-as-a-service. Its main goal is to provide a simple and generi

Alessandro Falcetta 8 Jan 06, 2023
Algo-burn - Script to configure an Algorand address as a "burn" address for one or more ASA tokens

Algorand Burn Address This is a simple script to illustrate how a "burn address"

GSD 5 May 10, 2022
Implicit Model Specialization through DAG-based Decentralized Federated Learning

Federated Learning DAG Experiments This repository contains software artifacts to reproduce the experiments presented in the Middleware '21 paper "Imp

Operating Systems and Middleware Group 5 Oct 16, 2022