Explainability for Vision Transformers (in PyTorch)

Overview

This repository implements methods for explainability in Vision Transformers.

See also https://jacobgil.github.io/deeplearning/vision-transformer-explainability

Currently implemented:

  • Attention Rollout.

  • Gradient Attention Rollout for class specific explainability. This is our attempt to further build upon and improve Attention Rollout.

  • Attention Flow (TBD, work in progress).

Includes some tweaks and tricks to get it working:

  • Different Attention Head fusion methods,
  • Removing the lowest attentions.
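
For orientation, the core of Attention Rollout (before either of the tweaks listed above) can be sketched roughly as follows. This is a minimal illustration, not the repository's code: the function name `attention_rollout_sketch` is hypothetical, the attention tensors are assumed to have already been collected from each layer, and token 0 is assumed to be the class token.

```python
import torch

def attention_rollout_sketch(attentions):
    """attentions: list with one (heads, tokens, tokens) tensor per layer."""
    num_tokens = attentions[0].size(-1)
    identity = torch.eye(num_tokens)
    result = identity.clone()
    for attention in attentions:
        fused = attention.mean(dim=0)        # fuse heads by averaging
        a = fused + identity                 # account for the residual connection
        a = a / a.sum(dim=-1, keepdim=True)  # re-normalize each row to sum to 1
        result = a @ result                  # accumulate attention across layers
    # Assumption: row 0 is the class token; keep its attention to the patch tokens.
    return result[0, 1:]

# Toy input: 2 layers, 3 heads, 5 tokens (1 class token + 4 patches).
attentions = [torch.softmax(torch.randn(3, 5, 5), dim=-1) for _ in range(2)]
mask = attention_rollout_sketch(attentions)  # shape (4,), one value per patch
```

The resulting vector would then be reshaped to the patch grid and resized to the image to produce a heat map.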

Usage

  • From code
import torch
from vit_grad_rollout import VITAttentionGradRollout

model = torch.hub.load('facebookresearch/deit:main', 'deit_tiny_patch16_224', pretrained=True)
grad_rollout = VITAttentionGradRollout(model, discard_ratio=0.9, head_fusion='max')
# input_tensor: a preprocessed image batch of shape (1, 3, 224, 224)
mask = grad_rollout(input_tensor, category_index=243)
  • From the command line:
python vit_explain.py --image_path <image path> --head_fusion <mean, min or max> --discard_ratio <number between 0 and 1> --category_index <category index>

If category_index isn't specified, Attention Rollout is used; otherwise, Gradient Attention Rollout is used.

Note that by default this uses the 'Tiny' model from "Training data-efficient image transformers & distillation through attention" (DeiT), hosted on Torch Hub.

Where did the Transformer pay attention in this image?

[Figure columns: Image, Vanilla Attention Rollout, With discard_ratio + max fusion]

Gradient Attention Rollout for class specific explainability

The attention that flows through the transformer passes along information belonging to different classes. Attention Rollout lets us see which locations the network paid attention to, but it tells us nothing about whether the network ended up using those locations for the final classification.

We can multiply the attention with the gradient of the target class output, and take the average among the attention heads (while masking out negative attentions) to keep only attention that contributes to the target category (or categories).
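
A rough sketch of this gradient-weighted variant is shown below. It is an illustration under assumptions rather than the repository's implementation: `grad_rollout_sketch` is a hypothetical name, the per-layer attention tensors and their gradients with respect to the target class score are assumed to have been captured already (for example with hooks), and negative values are masked after averaging over heads.

```python
import torch

def grad_rollout_sketch(attentions, gradients):
    """Both arguments: lists with one (heads, tokens, tokens) tensor per layer."""
    num_tokens = attentions[0].size(-1)
    identity = torch.eye(num_tokens)
    result = identity.clone()
    for attention, grad in zip(attentions, gradients):
        # Weight attention by the gradient of the target class score,
        # average over heads, then zero out negative contributions.
        fused = (attention * grad).mean(dim=0).clamp(min=0)
        a = fused + identity                 # account for the residual connection
        a = a / a.sum(dim=-1, keepdim=True)  # re-normalize each row
        result = a @ result                  # accumulate across layers
    # Assumption: row 0 is the class token; return its attention to the patches.
    return result[0, 1:]

# Toy input: 2 layers, 3 heads, 5 tokens (1 class token + 4 patches).
attentions = [torch.softmax(torch.randn(3, 5, 5), dim=-1) for _ in range(2)]
gradients = [torch.randn(3, 5, 5) for _ in range(2)]
mask = grad_rollout_sketch(attentions, gradients)  # shape (4,)
```

If every gradient is zero or negative, nothing survives the masking step and the mask is all zeros, which matches the intent of keeping only attention that supports the target category.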

Where does the Transformer see a Dog (category 243), and a Cat (category 282)?

Where does the Transformer see a Musket dog (category 161) and a Parrot (category 87)?

Tricks and Tweaks to get this working

Filtering the lowest attentions in every layer

--discard_ratio

Removes noise by keeping the strongest attentions.
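
One way this filtering could look for a single layer is sketched below. The helper name `discard_lowest_sketch` is hypothetical, the input is assumed to be the head-fused attention matrix of one layer, and keeping flat position 0 (the class token's own entry) is an assumption of this sketch.

```python
import torch

def discard_lowest_sketch(fused, discard_ratio=0.9):
    """fused: (tokens, tokens) head-fused attention for one layer."""
    flat = fused.flatten().clone()  # copy so the input is left untouched
    num_discard = int(flat.numel() * discard_ratio)
    # Indices of the `num_discard` smallest attention values.
    _, indices = flat.topk(num_discard, largest=False)
    # Assumption: never drop flat position 0 (class token attending to itself).
    indices = indices[indices != 0]
    flat[indices] = 0
    return flat.view_as(fused)

fused = torch.softmax(torch.randn(5, 5), dim=-1)
filtered = discard_lowest_sketch(fused, discard_ratio=0.9)
```

With discard_ratio=0.9, roughly the lowest 90% of entries in each layer are zeroed before the layers are multiplied together.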

Results for different values:

Different Attention Head Fusions

The Attention Rollout method suggests taking the average attention across the attention heads, but empirically it looks like taking the minimum value, or the maximum value combined with --discard_ratio, works better.

--head_fusion

[Figure columns: Image, Mean Fusion, Min Fusion]
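
The three fusion options can be illustrated with a small sketch. The function name `fuse_heads_sketch` is hypothetical, and the option strings 'mean', 'max' and 'min' are assumed to mirror the values accepted by --head_fusion.

```python
import torch

def fuse_heads_sketch(attention, head_fusion="mean"):
    """attention: (heads, tokens, tokens) -> fused (tokens, tokens)."""
    if head_fusion == "mean":
        return attention.mean(dim=0)
    if head_fusion == "max":
        return attention.max(dim=0).values
    if head_fusion == "min":
        return attention.min(dim=0).values
    raise ValueError(f"Unsupported head_fusion: {head_fusion!r}")

attention = torch.softmax(torch.randn(3, 5, 5), dim=-1)
fused_min = fuse_heads_sketch(attention, "min")
```

Min fusion keeps a location only if every head attends to it, while max fusion keeps anything that at least one head attends to, which may be why max tends to need --discard_ratio to suppress noise.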

Requirements

pip install timm

Owner
Jacob Gildenblat
Machine learning / Computer Vision developer.