Fast Axiomatic Attribution for Neural Networks (NeurIPS 2021)

Overview

This is the official repository accompanying the NeurIPS 2021 paper:

R. Hesse, S. Schaub-Meyer, and S. Roth. Fast axiomatic attribution for neural networks. NeurIPS, 2021, to appear.

Paper | Preprint (arXiv) | Project Page | Video

The repository contains:

  • Pre-trained X-DNN variants of popular image classification models, obtained by removing the bias term of each layer
  • Detailed information on how to easily compute axiomatic attributions in closed form for your own project
  • PyTorch code to reproduce the main experiments in the paper

Pretrained Models

Removing the bias from different image classification models has a surprisingly minor impact on their predictive accuracy, while allowing axiomatic attributions to be computed efficiently. Results of popular models with and without bias term (regular vs. X-) on the ImageNet validation split are listed below; a loading sketch follows the table.

Model       | Top-5 Accuracy (%) | Download
AlexNet     | 79.21 | alexnet_model_best.pth.tar
X-AlexNet   | 78.54 | xalexnet_model_best.pth.tar
VGG16       | 90.44 | vgg16_model_best.pth.tar
X-VGG16     | 90.25 | xvgg16_model_best.pth.tar
ResNet-50   | 92.56 | fixup_resnet50_model_best.pth.tar
X-ResNet-50 | 91.12 | xfixup_resnet50_model_best.pth.tar
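
The checkpoints can be loaded with standard PyTorch tooling. The following is a minimal loading sketch, not the repository's official loading code: the import path of the X-DNN architectures and the exact checkpoint layout are assumptions (checkpoints named *_model_best.pth.tar typically follow the PyTorch ImageNet example and wrap the weights in a dict under a 'state_dict' key), so adjust both to the code in this repository.

import torch

# Assumption: the X-DNN architectures ship with this repository;
# the exact module path may differ from the one used here.
from imagenet.models import xalexnet

model = xalexnet()

# Assumption: the checkpoint follows the PyTorch ImageNet example and stores
# the weights under the 'state_dict' key.
checkpoint = torch.load("xalexnet_model_best.pth.tar", map_location="cpu")
state_dict = checkpoint.get("state_dict", checkpoint)

# Strip a potential 'module.' prefix left over from DataParallel training.
state_dict = {k.replace("module.", "", 1): v for k, v in state_dict.items()}

model.load_state_dict(state_dict)
model.eval()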

Using X-Gradient in Your Own Project

In the following, we illustrate how to efficiently compute axiomatic attributions for X-DNNs. For a detailed example, please see demo.ipynb.
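
The snippets below assume you already have a model, an input batch inputs, and a tensor target holding the class indices to explain. A minimal, purely illustrative setup (standard ImageNet preprocessing, a single image, a placeholder file name and class index) could look like this:

import torch
from PIL import Image
from torchvision import transforms

# Hypothetical setup: standard ImageNet preprocessing for a single image.
# Replace the image path and the target class with your own data.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
inputs = preprocess(Image.open("example.jpg").convert("RGB")).unsqueeze(0)  # shape (1, 3, 224, 224)
target = torch.tensor([0])  # class indices to explain (example value)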

First, make sure that requires_grad of your input is set to True and run a forward pass:

inputs.requires_grad = True

# forward pass
outputs = model(inputs)

Next, you can compute X-Gradient via:

# compute attribution
target_outputs = torch.gather(outputs, 1, target.unsqueeze(-1))
# set create_graph=False if the attribution is only used for evaluation
gradients = torch.autograd.grad(torch.unbind(target_outputs), inputs, create_graph=True)[0]
xgradient_attributions = inputs * gradients
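
Since X-DNNs are bias-free and hence positively homogeneous, X-Gradient satisfies the completeness axiom: the attributions of each sample sum to the corresponding target logit. A quick sanity check, assuming 4-D image inputs and allowing for some numerical tolerance:

# completeness check (X-DNNs only): per-sample attributions should sum to the
# target logit, up to numerical precision
attribution_sums = xgradient_attributions.sum(dim=(1, 2, 3))
print(torch.allclose(attribution_sums, target_outputs.squeeze(-1), atol=1e-3))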

If the attribution is only used for evaluation, you can set create_graph to False. If you want to use the attribution for training, e.g., for training with attribution priors, you can define attribution_prior() and update the weights of your model:

loss1 = criterion(outputs, target)                 # standard loss
loss2 = attribution_prior(xgradient_attributions)  # attribution prior

loss = loss1 + lam * loss2  # lam: weighting factor for loss2 ("lambda" is a reserved keyword in Python)

optimizer.zero_grad()
loss.backward()
optimizer.step()
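
attribution_prior() is a placeholder for whatever prior suits your task. As one hypothetical example (not necessarily the prior used in our experiments), a simple sparsity prior penalizing the mean absolute attribution could be defined as:

# hypothetical attribution prior: encourage sparse attributions by
# penalizing their mean absolute value
def attribution_prior(attributions):
    return attributions.abs().mean()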

Reproducing Experiments

The code and a README with detailed instructions on how to reproduce the results of the experiments in Sec. 4.1, Sec. 4.2, and Sec. 4.4 of our paper can be found in the imagenet folder. To reproduce the results of the experiment in Sec. 4.3, please refer to the sparsity folder.

Prerequisites

  • Clone the repository: git clone https://github.com/visinf/fast-axiomatic-attribution.git
  • Set up the environment:
    • add the required conda channels and create a new environment:
    • conda config --add channels pytorch
    • conda config --add channels anaconda
    • conda config --add channels pipy
    • conda config --add channels conda-forge
    • conda create --name fast-axiomatic-attribution --file requirements.txt
  • Download ImageNet (ILSVRC2012)

Acknowledgments

We would like to thank the contributors of the following repositories, parts of whose publicly available code we use:

Citation

If you find our work helpful, please consider citing:

@inproceedings{Hesse:2021:FAA,
  title     = {Fast Axiomatic Attribution for Neural Networks},
  author    = {Hesse, Robin and Schaub-Meyer, Simone and Roth, Stefan},
  booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
  volume    = {34},
  year      = {2021}
}