Explainability for Vision Transformers (in PyTorch)

Overview

Explainability for Vision Transformers (in PyTorch)

This repository implements methods for explainability in Vision Transformers.

See also https://jacobgil.github.io/deeplearning/vision-transformer-explainability

Currently implemented:

  • Attention Rollout.

  • Gradient Attention Rollout for class specific explainability. This is our attempt to further build upon and improve Attention Rollout.

  • TBD Attention flow is work in progress.

Includes some tweaks and tricks to get it working:

  • Different Attention Head fusion methods,
  • Removing the lowest attentions.

Usage

  • From code
from vit_grad_rollout import VITAttentionGradRollout

model = torch.hub.load('facebookresearch/deit:main', 
'deit_tiny_patch16_224', pretrained=True)
grad_rollout = VITAttentionGradRollout(model, discard_ratio=0.9, head_fusion='max')
mask = grad_rollout(input_tensor, category_index=243)
  • From the command line:
python vit_explain.py --image_path  --head_fusion  --discard_ratio  --category_index 

If category_index isn't specified, Attention Rollout will be used, otherwise Gradient Attention Rollout will be used.

Notice that by default, this uses the 'Tiny' model from Training data-efficient image transformers & distillation through attention hosted on torch hub.

Where did the Transformer pay attention to in this image?

Image Vanilla Attention Rollout With discard_ratio+max fusion

Gradient Attention Rollout for class specific explainability

The Attention that flows in the transformer passes along information belonging to different classes. Gradient roll out lets us see what locations the network paid attention too, but it tells us nothing about if it ended up using those locations for the final classification.

We can multiply the attention with the gradient of the target class output, and take the average among the attention heads (while masking out negative attentions) to keep only attention that contributes to the target category (or categories).

Where does the Transformer see a Dog (category 243), and a Cat (category 282)?

Where does the Transformer see a Musket dog (category 161) and a Parrot (category 87):

Tricks and Tweaks to get this working

Filtering the lowest attentions in every layer

--discard_ratio

Removes noise by keeping the strongest attentions.

Results for dIfferent values:

Different Attention Head Fusions

The Attention Rollout method suggests taking the average attention accross the attention heads,

but emperically it looks like taking the Minimum value, Or the Maximum value combined with --discard_ratio, works better.

--head_fusion

Image Mean Fusion Min Fusion

References

Requirements

pip install timm

Owner
Jacob Gildenblat
Machine learning / Computer Vision developer.
Jacob Gildenblat
CoReNet is a technique for joint multi-object 3D reconstruction from a single RGB image.

CoReNet CoReNet is a technique for joint multi-object 3D reconstruction from a single RGB image. It produces coherent reconstructions, where all objec

Google Research 80 Dec 25, 2022
Objective of the repository is to learn and build machine learning models using Pytorch. 30DaysofML Using Pytorch

30 Days Of Machine Learning Using Pytorch Objective of the repository is to learn and build machine learning models using Pytorch. List of Algorithms

Mayur 119 Nov 24, 2022
Easy-to-use,Modular and Extendible package of deep-learning based CTR models .

DeepCTR DeepCTR is a Easy-to-use,Modular and Extendible package of deep-learning based CTR models along with lots of core components layers which can

浅梦 6.6k Jan 08, 2023
This repository contains an implementation of the Permutohedral Attention Module in Pytorch

Permutohedral_attention_module This repository contains an implementation of the Permutohedral Attention Module

Samuel JOUTARD 26 Nov 27, 2022
PIXIE: Collaborative Regression of Expressive Bodies

PIXIE: Collaborative Regression of Expressive Bodies [Project Page] This is the official Pytorch implementation of PIXIE. PIXIE reconstructs an expres

Yao Feng 331 Jan 04, 2023
DecoupledNet is semantic segmentation system which using heterogeneous annotations

DecoupledNet: Decoupled Deep Neural Network for Semi-supervised Semantic Segmentation Created by Seunghoon Hong, Hyeonwoo Noh and Bohyung Han at POSTE

Hyeonwoo Noh 74 Sep 22, 2021
TLoL (Python Module) - League of Legends Deep Learning AI (Research and Development)

TLoL-py - League of Legends Deep Learning Library TLoL-py is the Python component of the TLoL League of Legends deep learning library. It provides a s

7 Nov 29, 2022
End-to-end speech secognition toolkit

End-to-end speech secognition toolkit This is an E2E ASR toolkit modified from Espnet1 (version 0.9.9). This is the official implementation of paper:

Jinchuan Tian 147 Dec 28, 2022
A graph neural network (GNN) model to predict protein-protein interactions (PPI) with no sample features

A graph neural network (GNN) model to predict protein-protein interactions (PPI) with no sample features

2 Jul 25, 2022
TransZero++: Cross Attribute-guided Transformer for Zero-Shot Learning

TransZero++ This repository contains the testing code for the paper "TransZero++: Cross Attribute-guided Transformer for Zero-Shot Learning" submitted

Shiming Chen 6 Aug 16, 2022
Neural Dynamic Policies for End-to-End Sensorimotor Learning

This is a PyTorch based implementation for our NeurIPS 2020 paper on Neural Dynamic Policies for end-to-end sensorimotor learning.

Shikhar Bahl 47 Dec 11, 2022
An implementation of an abstract algebra for music tones (pitches).

nbdev template Use this template to more easily create your nbdev project. If you are using an older version of this template, and want to upgrade to

Open Music Kit 0 Oct 10, 2022
Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network

Super Resolution Examples We run this script under TensorFlow 2.0 and the TensorLayer2.0+. For TensorLayer 1.4 version, please check release. 🚀 🚀 🚀

TensorLayer Community 2.9k Jan 08, 2023
ExCon: Explanation-driven Supervised Contrastive Learning

ExCon: Explanation-driven Supervised Contrastive Learning Link to the paper: https://arxiv.org/pdf/2111.14271.pdf Contributors of this repo: Zhibo Zha

Zhibo (Darren) Zhang 18 Nov 01, 2022
A learning-based data collection tool for human segmentation

FullBodyFilter A Learning-Based Data Collection Tool For Human Segmentation Contents Documentation Source Code and Scripts Overview of Project Usage O

Robert Jiang 4 Jun 24, 2022
TorchMD-Net provides state-of-the-art graph neural networks and equivariant transformer neural networks potentials for learning molecular potentials

TorchMD-net TorchMD-Net provides state-of-the-art graph neural networks and equivariant transformer neural networks potentials for learning molecular

TorchMD 104 Jan 03, 2023
Implementation of the final project of the course DDA6309 Probabilistic Graphical Model

Task-aware Joint CWS and POS (TCwsPos) This is the implementation of the final project of the course DDA6309 Probabilistic Graphical Models, The Chine

Peng 1 Dec 26, 2021
Towards Representation Learning for Atmospheric Dynamics (AtmoDist)

Towards Representation Learning for Atmospheric Dynamics (AtmoDist) The prediction of future climate scenarios under anthropogenic forcing is critical

Sebastian Hoffmann 4 Dec 15, 2022
Source code for GNN-LSPE (Graph Neural Networks with Learnable Structural and Positional Representations)

Graph Neural Networks with Learnable Structural and Positional Representations Source code for the paper "Graph Neural Networks with Learnable Structu

Vijay Prakash Dwivedi 180 Dec 22, 2022
Music library streaming app written in Flask & VueJS

djtaytay This is a little toy app made to explore Vue, brush up on my Python, and make a remote music collection accessable through a web interface. I

Ryan Tasson 6 May 27, 2022