[ICML 2022] The official implementation of Graph Stochastic Attention (GSAT).

Overview

Graph Stochastic Attention (GSAT)

The official implementation of GSAT for our paper: Interpretable and Generalizable Graph Learning via Stochastic Attention Mechanism, to appear in ICML 2022.

Introduction

Commonly used attention mechanisms do not impose any constraints during training (besides normalization), and thus may lack interpretability. GSAT is a novel attention mechanism for building interpretable graph learning models. It injects stochasticity to learn attention, where a higher attention weight means a higher probability of the corresponding edge being kept during training. Such a mechanism will push the model to learn higher attention weights for edges that are important for prediction accuracy, which provides interpretability. To further improve the interpretability for graph learning tasks and avoid trivial solutions, we derive regularization terms for GSAT based on the information bottleneck (IB) principle. As a by-product, IB also helps model generalization. Fig. 1 shows the architecture of GSAT.

Figure 1. The architecture of GSAT.

Installation

We have tested our code on Python 3.9 with PyTorch 1.10.0, PyG 2.0.3 and CUDA 11.3. Follow the steps below to create a virtual environment and install the required packages.

Create a virtual environment:

conda create --name gsat python=3.9
conda activate gsat

Install dependencies:

conda install -y pytorch==1.10.0 torchvision cudatoolkit=11.3 -c pytorch
pip install torch-scatter==2.0.9 torch-sparse==0.6.12 torch-cluster==1.5.9 torch-spline-conv==1.2.1 torch-geometric==2.0.3 -f https://data.pyg.org/whl/torch-1.10.0+cu113.html
pip install -r requirements.txt

If a lower CUDA version is required, use the following commands to install dependencies instead:

conda install -y pytorch==1.9.0 torchvision==0.10.0 torchaudio==0.9.0 cudatoolkit=10.2 -c pytorch
pip install torch-scatter==2.0.9 torch-sparse==0.6.12 torch-cluster==1.5.9 torch-spline-conv==1.2.1 torch-geometric==2.0.3 -f https://data.pyg.org/whl/torch-1.9.0+cu102.html
pip install -r requirements.txt

Run Examples

We provide examples with minimal code to run GSAT in ./example/example.ipynb. We have tested the provided examples on Ba-2Motifs (GIN), Mutag (GIN) and OGBG-Molhiv (PNA). Note that to run GSAT* in the provided example, a pre-trained model must be loaded first.

The example notebook should run on other datasets as well, but some hard-coded hyperparameters might need to be changed accordingly. To reproduce results for other datasets, please follow the instructions in the following section.

Reproduce Results

We provide the source code to reproduce the results in our paper. The results of GSAT can be reproduced by running run_gsat.py. To reproduce GSAT*, one needs to run pretrain_clf.py first and change the configuration file accordingly (from_scratch: false).

To pre-train a classifier:

cd ./src
python pretrain_clf.py --dataset [dataset_name] --backbone [model_name] --cuda [GPU_id]

To train GSAT:

cd ./src
python run_gsat.py --dataset [dataset_name] --backbone [model_name] --cuda [GPU_id]

dataset_name can be chosen from ba_2motifs, mutag, mnist, Graph-SST2, spmotif_0.5, spmotif_0.7, spmotif_0.9, ogbg_molhiv, ogbg_moltox21, ogbg_molbace, ogbg_molbbbp, ogbg_molclintox, ogbg_molsider.

model_name can be chosen from GIN, PNA.

GPU_id is the id of the GPU to use. To use CPU, please set it to -1.
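
As a convenience, the command lines above can be assembled and validated programmatically. The sketch below is not part of the repository: build_command is a hypothetical helper, and the only things taken from this README are the script names, the three flags, and the allowed dataset and backbone names.

```python
import shlex

# Names copied from the lists above; the helper itself is hypothetical.
DATASETS = [
    "ba_2motifs", "mutag", "mnist", "Graph-SST2",
    "spmotif_0.5", "spmotif_0.7", "spmotif_0.9",
    "ogbg_molhiv", "ogbg_moltox21", "ogbg_molbace",
    "ogbg_molbbbp", "ogbg_molclintox", "ogbg_molsider",
]
BACKBONES = ["GIN", "PNA"]
SCRIPTS = ["pretrain_clf.py", "run_gsat.py"]


def build_command(script, dataset, backbone, cuda=-1):
    """Return the argv list for one run; cuda=-1 selects the CPU."""
    if script not in SCRIPTS:
        raise ValueError(f"unknown script: {script}")
    if dataset not in DATASETS:
        raise ValueError(f"unknown dataset: {dataset}")
    if backbone not in BACKBONES:
        raise ValueError(f"unknown backbone: {backbone}")
    return [
        "python", script,
        "--dataset", dataset,
        "--backbone", backbone,
        "--cuda", str(cuda),
    ]


if __name__ == "__main__":
    # Print the two commands needed for GSAT* on one dataset (run from ./src).
    for script in SCRIPTS:
        print(shlex.join(build_command(script, "ba_2motifs", "GIN", cuda=0)))
```

The returned list can be passed to subprocess.run with cwd set to ./src if the runs should be launched from Python rather than printed.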

Training Logs

Standard output provides basic training logs, while more detailed logs and interpretation visualizations can be found on tensorboard:

tensorboard --logdir=./data/[dataset_name]/logs

Hyperparameter Settings

All settings can be found in ./src/configs.

Instructions on Acquiring Datasets

  • Ba_2Motifs

    • Raw data files can be downloaded automatically, provided by PGExplainer and DIG.
  • Spurious-Motif

    • Raw data files can be generated automatically, provided by DIR.
  • OGBG-Mol

    • Raw data files can be downloaded automatically, provided by OGBG.
  • Mutag

    • Raw data files need to be downloaded here, provided by PGExplainer.
    • Unzip Mutagenicity.zip and Mutagenicity.pkl.zip.
    • Put the raw data files in ./data/mutag/raw.
  • Graph-SST2

    • Raw data files need to be downloaded here, provided by DIG.
    • Unzip the downloaded Graph-SST2.zip.
    • Put the raw data files in ./data/Graph-SST2/raw.
  • MNIST-75sp

    • Raw data files need to be generated following the instruction here.
    • Put the generated files in ./data/mnist/raw.

FAQ

Does GSAT encourage sparsity?

No, GSAT does not encourage generating sparse subgraphs. We find that r = 0.7 (Eq.(9) in our paper) generally works well for all datasets in our experiments, which means that roughly 70% of edges are kept during training, a fairly large fraction. This is because GSAT does not try to provide interpretability by finding a small/sparse subgraph of the original input graph, which is what previous works normally do and which hurts performance significantly for inherently interpretable models (as shown in Fig. 7 in the paper). By contrast, GSAT provides interpretability by pushing the critical edges to have relatively lower stochasticity during training.

How to choose the value of r?

A grid search over [0.5, 0.6, 0.7, 0.8, 0.9] is recommended, but r = 0.7 is a good starting point. Note that in practice we decay the value of r gradually during training, from 0.9 down to the chosen value.
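
One simple way to realize such a decay is a step schedule. The sketch below is only an illustration under assumptions: the function name decayed_r and the parameters decay_interval and decay_step are hypothetical, and the actual schedule and its values are set in the configuration files, which may differ.

```python
def decayed_r(epoch, final_r=0.7, init_r=0.9, decay_interval=10, decay_step=0.1):
    """Step-decay r from init_r towards final_r.

    r drops by decay_step every decay_interval epochs and is clipped
    at final_r. All default values here are illustrative assumptions.
    """
    r = init_r - (epoch // decay_interval) * decay_step
    return max(r, final_r)


if __name__ == "__main__":
    # Show the schedule at a few epochs for the default (assumed) settings.
    for epoch in (0, 9, 10, 20, 50):
        print(f"epoch {epoch:3d}: r = {decayed_r(epoch):.2f}")
```

With this shape, early epochs keep most edges, and the constraint tightens only after the classifier has had some time to train.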

Should p or α be used to implement Eq.(9)?

Recall from Fig. 1 that p is the probability of keeping an edge, while α is the result sampled from Bern(p). In our provided implementation, as an empirical choice, α is used to implement Eq.(9) (the Gumbel-softmax trick makes α essentially continuous in practice). We find that using α may provide more regularization and make the model more robust to hyperparameters. Nonetheless, using p can achieve the same performance, but it needs some more tuning.
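
To see why α is continuous in practice, here is a minimal standard-library sketch of a relaxed Bernoulli sample. For two classes, the Gumbel-softmax trick reduces to adding logistic noise to the logit of p and applying a tempered sigmoid. The function names and the temperature default are assumptions for illustration; the repository's own sampling code may differ in detail.

```python
import math
import random


def _sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def relaxed_bernoulli_sample(p, temperature=1.0, rng=random):
    """Draw a continuous sample alpha in [0, 1] from a relaxed Bern(p).

    p must lie strictly between 0 and 1. Lower temperatures push the
    sample towards 0 or 1; higher temperatures keep it closer to 0.5.
    """
    eps = 1e-10
    u = min(max(rng.random(), eps), 1.0 - eps)   # uniform noise, kept away from 0 and 1
    logit = math.log(p) - math.log(1.0 - p)      # log-odds of p
    noise = math.log(u) - math.log(1.0 - u)      # logistic noise
    return _sigmoid((logit + noise) / temperature)


if __name__ == "__main__":
    rng = random.Random(0)
    for p in (0.1, 0.7, 0.95):
        samples = [relaxed_bernoulli_sample(p, rng=rng) for _ in range(5)]
        print(p, [round(s, 3) for s in samples])
```

Because the sample is a smooth function of p, gradients can flow from the loss back to p, which is what allows either p or α to be plugged into the regularizer.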

Can you show an example of how GSAT works?

Below we show an example from the ba_2motifs dataset, where the task is to distinguish five-node cycle motifs (left) from house motifs (right). To make good predictions (minimize the cross-entropy loss), GSAT will push the attention weights of the critical edges to be relatively large (ideally close to 1). Otherwise, those critical edges may be dropped too frequently and thus result in a large cross-entropy loss. Meanwhile, to minimize the regularization loss (the KL divergence term in Eq.(9) of the paper), GSAT will push the attention weights of the other, non-critical edges to be close to r, which is set to 0.7 in the example. This mechanism of injecting stochasticity makes the learned attention weights from GSAT directly interpretable, since the more critical an edge is, the larger its attention weight will be (and the less likely it is to be dropped). Note that ba_2motifs satisfies our Thm. 4.1 with no noise, and GSAT achieves perfect interpretation performance on it.
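
The pull towards r can be checked numerically. The sketch below evaluates the standard KL divergence between Bern(att) and Bern(r) for single attention weights. Eq.(9) itself is not reproduced in this README, so the function name, the clipping constant eps, and the omission of any loss weighting are assumptions made for illustration.

```python
import math


def bernoulli_kl(att, r, eps=1e-6):
    """KL(Bern(att) || Bern(r)) for one attention weight.

    att is clipped to [eps, 1 - eps] to avoid log(0); r must lie
    strictly between 0 and 1.
    """
    att = min(max(att, eps), 1.0 - eps)
    return (att * math.log(att / r)
            + (1.0 - att) * math.log((1.0 - att) / (1.0 - r)))


if __name__ == "__main__":
    r = 0.7
    # The penalty vanishes at att == r and grows as att moves away from r.
    for att in (0.3, 0.5, 0.7, 0.9, 0.99):
        print(f"att = {att:.2f}  KL = {bernoulli_kl(att, r):.4f}")
```

The term is zero only at att = r, so an edge ends up with a weight near 1 only when the gain in cross-entropy outweighs this penalty, which matches the behavior described above for critical edges.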

Figure 2. An example of the learned attention weights.

Reference

If you find our paper and repo useful, please cite our paper:

@article{miao2022interpretable,
  title={Interpretable and Generalizable Graph Learning via Stochastic Attention Mechanism},
  author={Miao, Siqi and Liu, Miaoyuan and Li, Pan},
  journal={arXiv preprint arXiv:2201.12987},
  year={2022}
}