PyTorch implementation of Interpretable Explanations of Black Boxes by Meaningful Perturbation

Overview

The paper: https://arxiv.org/abs/1704.03296

What makes the deep learning network think the image label is 'pug, pug-dog' and 'tabby, tabby cat':

[Images: Dog, Cat]

A perturbation of the dog that caused the dog category score to vanish:

[Image: Perturbed]

What makes the deep learning network think the image label is 'flute, transverse flute':

[Image: Flute]


Usage: python explain.py <path_to_image>

This is a PyTorch implementation of "Interpretable Explanations of Black Boxes by Meaningful Perturbation" by Ruth Fong and Andrea Vedaldi, with some deviations.

This uses VGG19 from torchvision. It will be downloaded when used for the first time.

This learns a mask of pixels that explains the result of a black box. The mask is learned by posing an optimization problem and solving directly for the mask values.

This differs from other visualization techniques such as Grad-CAM, which rely on heuristics, for example treating high positive gradient values as an indication of relevance to the network score.

In our case the black box is the VGG19 model, but any differentiable model can be used.


How it works

[Image: Equation]

Taken from the paper https://arxiv.org/abs/1704.03296
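Since the equation image is not reproduced here, the objective can be sketched in this implementation's own terms. This is assembled from the properties listed below (score drop, L1 on 1 - mask, total variation, 28x28 mask, blend with a blurred copy) and is not a verbatim copy of the paper's equation; the symbols lambda_1, lambda_2, beta, M (the up-sampled mask), x_0 (the input image) and x-tilde_0 (its blurred version) are labels chosen here.

$$
m^\ast = \arg\min_{m \in [0,1]^{28 \times 28}} \; \lambda_1 \lVert \mathbf{1} - m \rVert_1 + \lambda_2 \sum_{u} \lVert \nabla m(u) \rVert_\beta^\beta + f_c\big(\Phi(x_0, m)\big),
\qquad
\Phi(x_0, m) = M \odot x_0 + (\mathbf{1} - M) \odot \tilde{x}_0
$$

Here f_c is the network's score for category c, and a mask value of 1 keeps the original pixel while 0 replaces it with the blurred one.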

The goal is to solve for a mask that explains why the network outputs a score for a certain category.

We create a low resolution (28x28) mask, and use it to perturb the input image to a deep learning network.

The perturbation combines a blurred version of the image, the regular image, and the up-sampled mask.

Wherever the mask contains low values, the input image will become more blurry.
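A minimal sketch of the up-sampling and blending step, assuming NCHW tensors, bilinear up-sampling, and the convention that a mask value of 1 keeps the original pixel and 0 uses the blurred pixel. The function name `perturb` is hypothetical, not taken from `explain.py`.

```python
import torch
import torch.nn.functional as F


def perturb(image, blurred, mask_lowres):
    """Blend an image with its blurred copy using an up-sampled low-resolution mask.

    image, blurred: (N, 3, H, W) tensors.
    mask_lowres:    (N, 1, 28, 28) tensor with values in [0, 1].
    A mask value near 1 keeps the original pixel; near 0 it becomes blurred.
    """
    # Up-sample the coarse mask to the image resolution.
    mask = F.interpolate(mask_lowres, size=image.shape[-2:],
                         mode="bilinear", align_corners=False)
    # The (N, 1, H, W) mask broadcasts over the 3 color channels.
    return mask * image + (1.0 - mask) * blurred
```

Optimizing a 28x28 mask and up-sampling it, rather than optimizing one value per input pixel, is what keeps the number of free parameters small.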

We want to optimize for the following properties:

  1. When the mask is used to blend the input image and its blurred version, the score of the target category should drop significantly. The evidence for the category should be removed!
  2. The mask should be sparse. Ideally it should be the smallest possible mask that makes the category score drop. This translates to an L1(1 - mask) term in the cost function.
  3. The mask should be smooth. This translates to a total variation regularization term in the cost function.
  4. The mask shouldn't over-fit the network. Since the network activations can contain a lot of noise, it is easy for the mask to learn arbitrary values that make the score drop without being visually coherent. In addition to the other terms, this is addressed by solving for a lower-resolution 28x28 mask.
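The three cost terms above can be sketched as a single loss function. This assumes the target score is the softmax probability of the category for the first image in the batch, that total variation is the mean of |difference|^beta between neighboring mask pixels, and that `l1_coeff`, `tv_coeff` and `tv_beta` take the illustrative values shown; all of these, and the function names, are assumptions rather than the repository's exact code.

```python
import torch
import torch.nn.functional as F


def total_variation(mask, beta=3.0):
    """Mean |difference|^beta over vertically plus horizontally adjacent mask pixels."""
    dy = (mask[..., 1:, :] - mask[..., :-1, :]).abs().pow(beta).mean()
    dx = (mask[..., :, 1:] - mask[..., :, :-1]).abs().pow(beta).mean()
    return dx + dy


def explanation_loss(model, perturbed_input, mask, category,
                     l1_coeff=0.01, tv_coeff=0.2, tv_beta=3.0):
    """Cost for one optimization step (coefficients are illustrative assumptions).

    mask: low-resolution mask in [0, 1], where 1 keeps a pixel and 0 blurs it.
    """
    # Property 1: the target category's probability should drop.
    score = F.softmax(model(perturbed_input), dim=1)[0, category]
    # Property 2: delete as little as possible -> L1 penalty on (1 - mask).
    sparsity = (1.0 - mask).abs().mean()
    # Property 3: a smooth mask -> total variation penalty.
    smoothness = total_variation(mask, tv_beta)
    return l1_coeff * sparsity + tv_coeff * smoothness + score
```

With an all-ones mask both regularizers are zero, so the loss reduces to the category score; the optimizer then trades a lower score against deleting more, or less smooth, regions.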

Deviations from the paper

The paper uses a Gaussian kernel whose sigma is modulated by the value of the mask. This is computationally expensive, since the mask values are updated during the iterations, meaning we need a different kernel for every mask pixel at every iteration.

Initially I tried approximating this by first filtering the image with a filter bank of Gaussian kernels with varying sigmas. Then, during optimization, each input image pixel would use its quantized mask value to select the appropriate filter bank output pixel (high mask value -> lower channel).

This was done using the PyTorch gather/select_index functions, but it turns out that these functions are not differentiable with respect to the indices.
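A small demonstration of why the filter-bank approach breaks autograd, using `torch.gather`. The filter bank here is random data standing in for precomputed blur levels, and `num_levels`, `bank` and the quantization rule are hypothetical. Quantizing the mask to integer indices cuts the graph, so no gradient can reach the mask; a plain arithmetic blend keeps it differentiable.

```python
import torch

num_levels, H, W = 4, 8, 8
# Hypothetical filter bank: channel 0 = least blurred, channel 3 = most blurred.
bank = torch.rand(num_levels, H, W)
mask = torch.rand(1, H, W, requires_grad=True)

# Quantize the mask to a channel index: high mask value -> lower channel.
level = ((1.0 - mask) * (num_levels - 1)).round().long()
# Pick, for every pixel, the filter-bank output selected by its index.
selected = torch.gather(bank, 0, level)  # shape (1, H, W)

# Casting to an integer dtype detaches the indices, so the mask gets no gradient.
print(level.requires_grad, selected.requires_grad)

# A differentiable alternative: blend two images arithmetically with the mask.
blended = mask * bank[0] + (1.0 - mask) * bank[-1]
print(blended.requires_grad)
```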

Instead, we just compute a perturbed image once, and then blend the image and the perturbed image using:

input_image = mask * image + (1 - mask) * perturbed_image

And it works well in practice.

The perturbed image here is the average of the Gaussian-blurred and median-blurred images, but many other combinations could work (try it out and find something better!).
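A pure-PyTorch sketch of the "average of Gaussian and median blur" perturbation. The kernel size of 11, sigma of 5, and reflect padding are assumptions for illustration; a real implementation might call an image library such as OpenCV instead. The function names are hypothetical.

```python
import torch
import torch.nn.functional as F


def gaussian_blur(img, ksize=11, sigma=5.0):
    """Separable Gaussian blur of an (N, C, H, W) tensor with reflect padding."""
    half = ksize // 2
    x = torch.arange(ksize, dtype=img.dtype) - half
    k1d = torch.exp(-x ** 2 / (2 * sigma ** 2))
    k1d = k1d / k1d.sum()  # normalize so constant regions are preserved
    c = img.shape[1]
    kh = k1d.view(1, 1, 1, ksize).repeat(c, 1, 1, 1)  # horizontal kernel per channel
    kv = k1d.view(1, 1, ksize, 1).repeat(c, 1, 1, 1)  # vertical kernel per channel
    out = F.conv2d(F.pad(img, (half, half, 0, 0), mode="reflect"), kh, groups=c)
    return F.conv2d(F.pad(out, (0, 0, half, half), mode="reflect"), kv, groups=c)


def median_blur(img, ksize=11):
    """Median over each ksize x ksize neighborhood of an (N, C, H, W) tensor."""
    half = ksize // 2
    n, c, h, w = img.shape
    padded = F.pad(img, (half, half, half, half), mode="reflect")
    # Extract sliding windows: (N, C, H, W, ksize, ksize).
    patches = padded.unfold(2, ksize, 1).unfold(3, ksize, 1)
    return patches.reshape(n, c, h, w, -1).median(dim=-1).values


def perturbed_image(img):
    """Average of the Gaussian-blurred and median-blurred images."""
    return 0.5 * (gaussian_blur(img) + median_blur(img))
```

Because the perturbed image is computed once before optimization, its cost does not matter much; only the cheap blend is repeated at every iteration.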

In addition, Gaussian noise with a sigma of 0.2 is added to the preprocessed image at each iteration, inspired by Google's SmoothGrad.
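Putting the pieces together, a minimal optimization loop might look like the following. It is a sketch under several assumptions: the mask starts at all ones, Adam is the optimizer, the mask is clamped to [0, 1] after each step, fresh noise with sigma 0.2 is added to the blended input at every iteration (one reading of "added to the preprocessed image"), and the iteration count, learning rate and loss weights are illustrative. `learn_mask` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def learn_mask(model, image, blurred, category, iterations=300, lr=0.1,
               l1_coeff=0.01, tv_coeff=0.2, tv_beta=3.0,
               noise_sigma=0.2, mask_size=28):
    """Optimize a low-resolution mask; 1 keeps a pixel, 0 replaces it with `blurred`."""
    mask = torch.ones(1, 1, mask_size, mask_size, requires_grad=True)
    optimizer = torch.optim.Adam([mask], lr=lr)
    for _ in range(iterations):
        up = F.interpolate(mask, size=image.shape[-2:],
                           mode="bilinear", align_corners=False)
        blended = up * image + (1.0 - up) * blurred
        # SmoothGrad-style: new Gaussian noise on the network input every iteration.
        noisy = blended + noise_sigma * torch.randn_like(blended)
        score = F.softmax(model(noisy), dim=1)[0, category]
        # Total variation of the low-resolution mask.
        dy = (mask[..., 1:, :] - mask[..., :-1, :]).abs().pow(tv_beta).mean()
        dx = (mask[..., :, 1:] - mask[..., :, :-1]).abs().pow(tv_beta).mean()
        loss = l1_coeff * (1.0 - mask).abs().mean() + tv_coeff * (dx + dy) + score
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Keep the mask a valid blending weight.
        with torch.no_grad():
            mask.clamp_(0.0, 1.0)
    return mask.detach()
```

Here `blurred` would be a perturbed image such as the Gaussian/median average described above, and `model` would be VGG19 in evaluation mode; low values in the returned mask mark the regions whose removal most reduces the category score.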

Owner
Jacob Gildenblat
Machine learning / Computer Vision.