Official Pytorch Implementation of GraphiT

Related tags

Deep LearningGraphiT
Overview

GraphiT: Encoding Graph Structure in Transformers

This repository implements GraphiT, described in the following paper:

Grégoire Mialon*, Dexiong Chen*, Margot Selosse*, Julien Mairal. GraphiT: Encoding Graph Structure in Transformers.
*Equal contribution

Short Description about GraphiT

Figure from paper

GraphiT is an instance of transformers designed for graph-structured data. It takes as input a graph seen as a set of its node features, and integrates the graph structure via i) relative positional encoding using kernels on graphs and ii) encoding local substructures around each node, e.g, short paths, before adding it to the node features. GraphiT is able to outperform Graph Neural Networks in different graph classification and regression tasks, and offers promising visualization capabilities for domains where interpretability is important, e.g, in chemoinformatics.

Installation

Environment:

numpy=1.18.1
scipy=1.3.2
Cython=0.29.23
scikit-learn=0.22.1
matplotlib=3.4
networkx=2.5
python=3.7
pytorch=1.6
torch-geometric=1.7

The train folds and model weights for visualization are already provided at the correct location. Datasets will be downloaded via Pytorch geometric.

To begin with, run:

cd GraphiT
. s_env

To install GCKN, you also need to run:

make

Training GraphiT on graph classification and regression tasks

All our experimental scripts are in the folder experiments. So to start with, run cd experiments.

Classification

To train GraphiT on NCI1 with diffusion kernel, run:

python run_transformer_cv.py --dataset NCI1 --fold-idx 1 --pos-enc diffusion --beta 1.0

Here --fold-idx can be varied from 1 to 10 to train on a specified training fold. To test a selected model, just add the --test flag.

To include Laplacian positional encoding into input node features, run:

python run_transformer_cv.py --dataset NCI1 --fold-idx 1 --pos-enc diffusion --beta 1.0 --lappe --lap-dim 8

To include GCKN path features into input node features, run:

python run_transformer_gckn_cv.py --dataset NCI1 --fold-idx 1 --pos-enc diffusion --beta 1.0 --gckn-path 5

Regression

To train GraphiT on ZINC, run:

python run_transformer.py --pos-enc diffusion --beta 1.0

To include Laplacian positional encoding into input node features, run:

python run_transformer.py --pos-enc diffusion --beta 1.0 --lappe --lap-dim 8

To include GCKN path features into input node features, run:

python run_transformer_gckn.py --pos-enc diffusion --beta 1.0 --gckn-path 8

Visualizing attention scores

To visualize attention scores for GraphiT trained on Mutagenicity, run:

cd experiments
python visu_attention.py --idx-sample 10

To visualize Nitrothiopheneamide-methylbenzene, choose 10 as sample index. To visualize Aminofluoranthene, choose 2003 as sample index. If you want to test for other samples (i.e, other indexes), make sure that the model correctly predicts mutagenicity (class 0) for this sample.

Citation

To cite GraphiT, please use the following Bibtex snippet:

@misc{mialon2021graphit,
      title={GraphiT: Encoding Graph Structure in Transformers}, 
      author={Gr\'egoire Mialon and Dexiong Chen and Margot Selosse and Julien Mairal},
      year={2021},
      eprint={2106.05667},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
Owner
Inria Thoth
A joint team of Inria and Laboratoire Jean Kuntzmann, we design models capable of representing visual information at scale from minimal supervision.
Inria Thoth
Object-aware Contrastive Learning for Debiased Scene Representation

Object-aware Contrastive Learning Official PyTorch implementation of "Object-aware Contrastive Learning for Debiased Scene Representation" by Sangwoo

43 Dec 14, 2022
[CVPR'21] Locally Aware Piecewise Transformation Fields for 3D Human Mesh Registration

Locally Aware Piecewise Transformation Fields for 3D Human Mesh Registration This repository contains the implementation of our paper Locally Aware Pi

sfwang 70 Dec 19, 2022
PyTorch implementations of algorithms for density estimation

pytorch-flows A PyTorch implementations of Masked Autoregressive Flow and some other invertible transformations from Glow: Generative Flow with Invert

Ilya Kostrikov 546 Dec 05, 2022
Evolving neural network parameters in JAX.

Evolving Neural Networks in JAX This repository holds code displaying techniques for applying evolutionary network training strategies in JAX. Each sc

Trevor Thackston 6 Feb 12, 2022
Planar Prior Assisted PatchMatch Multi-View Stereo

ACMP [News] The code for ACMH is released!!! [News] The code for ACMM is released!!! About This repository contains the code for the paper Planar Prio

Qingshan Xu 127 Dec 31, 2022
A machine learning project which can detect and predict the skin disease through image recognition.

ML-Project-2021 A machine learning project which can detect and predict the skin disease through image recognition. The dataset used for this is the H

Debshishu Ghosh 1 Jan 13, 2022
AntiFuzz: Impeding Fuzzing Audits of Binary Executables

AntiFuzz: Impeding Fuzzing Audits of Binary Executables Get the paper here: https://www.usenix.org/system/files/sec19-guler.pdf Usage: The python scri

Chair for Sys­tems Se­cu­ri­ty 88 Dec 21, 2022
Given a 2D triangle mesh, we could randomly generate cloud points that fill in the triangle mesh

generate_cloud_points Given a 2D triangle mesh, we could randomly generate cloud points that fill in the triangle mesh. Run python disp_mesh.py Or you

Peng Yu 2 Dec 24, 2021
Artificial Intelligence search algorithm base on Pacman

Pacman Search Artificial Intelligence search algorithm base on Pacman Source The Pacman Projects by the University of California, Berkeley. Layouts Di

Day Fundora 6 Nov 17, 2022
🛰️ Awesome Satellite Imagery Datasets

Awesome Satellite Imagery Datasets List of aerial and satellite imagery datasets with annotations for computer vision and deep learning. Newest datase

Christoph Rieke 3k Jan 03, 2023
DropNAS: Grouped Operation Dropout for Differentiable Architecture Search

DropNAS: Grouped Operation Dropout for Differentiable Architecture Search DropNAS, a grouped operation dropout method for one-level DARTS, with better

weijunhong 4 Aug 15, 2022
Cross-media Structured Common Space for Multimedia Event Extraction (ACL2020)

Cross-media Structured Common Space for Multimedia Event Extraction Table of Contents Overview Requirements Data Quickstart Citation Overview The code

Manling Li 49 Nov 21, 2022
QuickAI is a Python library that makes it extremely easy to experiment with state-of-the-art Machine Learning models.

QuickAI is a Python library that makes it extremely easy to experiment with state-of-the-art Machine Learning models.

152 Jan 02, 2023
Pytorch and Keras Implementations of Hyperspectral Image Classification -- Traditional to Deep Models: A Survey for Future Prospects.

The repository contains the implementations for Hyperspectral Image Classification -- Traditional to Deep Models: A Survey for Future Prospects. Model

Ankur Deria 115 Jan 06, 2023
Real-Time-Student-Attendence-System - Real Time Student Attendence System

Real-Time-Student-Attendence-System The Student Attendance Management System Pro

Rounak Das 1 Feb 15, 2022
A object detecting neural network powered by the yolo architecture and leveraging the PyTorch framework and associated libraries.

Yolo-Powered-Detector A object detecting neural network powered by the yolo architecture and leveraging the PyTorch framework and associated libraries

Luke Wilson 1 Dec 03, 2021
Self-Supervised Deep Blind Video Super-Resolution

Self-Blind-VSR Paper | Discussion Self-Supervised Deep Blind Video Super-Resolution By Haoran Bai and Jinshan Pan Abstract Existing deep learning-base

Haoran Bai 35 Dec 09, 2022
PyTorch evaluation code for Delving Deep into the Generalization of Vision Transformers under Distribution Shifts.

Out-of-distribution Generalization Investigation on Vision Transformers This repository contains PyTorch evaluation code for Delving Deep into the Gen

Chongzhi Zhang 72 Dec 13, 2022
PSANet: Point-wise Spatial Attention Network for Scene Parsing, ECCV2018.

PSANet: Point-wise Spatial Attention Network for Scene Parsing (in construction) by Hengshuang Zhao*, Yi Zhang*, Shu Liu, Jianping Shi, Chen Change Lo

Hengshuang Zhao 217 Oct 30, 2022
Lightweight Cuda Renderer with Python Wrapper.

pyRender Lightweight Cuda Renderer with Python Wrapper. Compile Change compile.sh line 5 to the glm library include path. This library can be download

Jingwei Huang 53 Dec 02, 2022