Learning from graph data using Keras

Overview

Steps to run =>

  • Download the Cora dataset from https://linqs.soe.ucsc.edu/data
  • Unzip the files into the folder input/cora
  • cd code
  • python eda.py
  • python word_features_only.py # for baseline model 53.28% accuracy
  • python graph_embedding.py # for model_1 73.06% accuracy
  • python graph_features_embedding.py # for model_2 76.35% accuracy

Learning from Graph Data using Keras and TensorFlow

Cora Data set Citation Graph

Motivation :

Much real-world data can be represented as a graph, for example in citation networks, social networks (follower graphs, friend networks, …), biological networks and telecommunications networks.
Features extracted from the graph can boost the performance of predictive models by exploiting the flow of information between nearby nodes. However, representing graph data is not straightforward, especially if we don't want to rely on hand-crafted features.
In this post we explore a few ways to perform node classification on generic graphs using graph representations learned directly from the data.

Dataset :

The Cora citation network dataset serves as the basis for the implementations and experiments throughout this post. Each node represents a scientific paper, and an edge between two nodes represents a citation between the two papers.
Each node is represented by a set of binary features (bag of words) as well as by the set of edges that link it to other nodes.
The dataset has 2708 nodes, each classified into one of seven classes, and 5429 links. Each binary word feature indicates whether the corresponding word is present in the paper; overall there are 1433 sparse binary features per node. In what follows we use only 140 samples for training and the rest for validation/test.
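The following minimal loading sketch uses Python and NumPy to make the dataset description concrete. It assumes the usual LINQS layout of the archive. In that layout, each line of a cora.content file holds a paper id, the binary word features and the class label. Each line of a cora.cites file holds a cited paper id followed by a citing paper id. The helper names parse_cora and train_test_masks are hypothetical. The random 140-node split is also an assumption, since the original repository may choose its training nodes differently.

```python
import numpy as np


def parse_cora(content_lines, cites_lines):
    """Parse Cora-style text into (features, labels, edges, class_names).

    Assumed layout (LINQS release): each content line is
    '<paper_id> <binary word features...> <class_label>' and each cites line
    is '<cited_paper_id> <citing_paper_id>', whitespace-separated.
    """
    ids, rows, label_names = [], [], []
    for line in content_lines:
        parts = line.split()
        if not parts:
            continue  # skip blank lines
        ids.append(parts[0])
        rows.append([int(value) for value in parts[1:-1]])
        label_names.append(parts[-1])

    index = {paper_id: i for i, paper_id in enumerate(ids)}
    features = np.array(rows, dtype=np.float32)
    class_names = sorted(set(label_names))
    class_to_idx = {name: i for i, name in enumerate(class_names)}
    labels = np.array([class_to_idx[name] for name in label_names])

    edges = []
    for line in cites_lines:
        parts = line.split()
        if len(parts) != 2:
            continue
        cited, citing = parts
        # Drop citations that refer to papers absent from the content file.
        if cited in index and citing in index:
            edges.append((index[citing], index[cited]))
    return features, labels, edges, class_names


def train_test_masks(num_nodes, num_train=140, seed=0):
    """Assumption: a random split with num_train training nodes, the rest held out."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(num_nodes)
    train = np.zeros(num_nodes, dtype=bool)
    train[order[:num_train]] = True
    return train, ~train
```

With the real files, pass open file handles for cora.content and cora.cites. Their exact paths depend on how the archive was extracted. On the full dataset the feature matrix should have 2708 rows and 1433 columns.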

Problem Setting :

Problem : Assigning a class label to nodes in a graph given only a few training samples.
Intuition/Hypothesis : Nodes that are close in the graph are more likely to have similar labels.
Solution : Find a way to extract features from the graph to help classify new nodes.
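One quick way to sanity-check this hypothesis on a given graph is to measure edge homophily. This is the fraction of edges whose two endpoints share a label, compared with the agreement expected by chance. The NumPy sketch below is illustrative only. The function names are hypothetical, and the post itself does not report this statistic.

```python
import numpy as np


def edge_homophily(edges, labels):
    """Fraction of edges (u, v) whose endpoints have the same class label."""
    labels = np.asarray(labels)
    if not edges:
        return float("nan")
    same = sum(int(labels[u] == labels[v]) for u, v in edges)
    return same / len(edges)


def chance_agreement(labels):
    """Probability that two nodes drawn at random (with replacement) share a label.

    Homophily well above this baseline is evidence that nearby nodes tend to
    have similar labels.
    """
    _, counts = np.unique(np.asarray(labels), return_counts=True)
    p = counts / counts.sum()
    return float(np.sum(p ** 2))
```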

Proposed Approach :


Baseline Model :

Simple Baseline Model

We first experiment with the simplest model, which learns to predict node classes using only the binary features and discards all graph information.
This model is a fully-connected Neural Network that takes as input the binary features and outputs the class probabilities for each node.
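The sketch below is a rough stand-in for this baseline. It trains a single dense softmax layer (multinomial logistic regression) on the binary features alone, using full-batch gradient descent in NumPy. This is a simplification, because the post's Keras model may use hidden layers, dropout or a different optimizer. Nothing here is claimed to reproduce the 53.28% figure. The function names and hyperparameters are assumptions.

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def train_softmax_classifier(X, y, num_classes, lr=0.5, epochs=200, l2=1e-4, seed=0):
    """Fit p(class | x) = softmax(x W + b) by full-batch gradient descent on the
    mean cross-entropy plus an L2 penalty (l2 / 2) * ||W||^2."""
    rng = np.random.default_rng(seed)
    n, d = X.shape
    W = rng.normal(scale=0.01, size=(d, num_classes))
    b = np.zeros(num_classes)
    Y = np.eye(num_classes)[y]  # one-hot targets
    for _ in range(epochs):
        P = softmax(X @ W + b)
        grad = (P - Y) / n  # gradient of mean cross-entropy w.r.t. the logits
        W -= lr * (X.T @ grad + l2 * W)
        b -= lr * grad.sum(axis=0)
    return W, b


def predict(X, W, b):
    return np.argmax(X @ W + b, axis=1)
```

On Cora, one would fit on the 140 training rows of the feature matrix and report accuracy on the held-out rows.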

Baseline model Accuracy : 53.28%

This is the initial accuracy that we will try to improve on by adding graph-based features.

Adding Graph features :

One way to learn graph features automatically is to embed each node into a vector by training a network on an auxiliary task: predicting the inverse of the shortest-path length between two input nodes, as illustrated in the figure below :

Learning an embedding vector for each node
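The auxiliary task can be sketched in plain NumPy. The code treats citations as undirected edges and computes shortest-path lengths with breadth-first search. It samples node pairs with target 1 / distance, using 0 when no path exists. It then fits one embedding vector per node. A sigmoid of the dot product of two embeddings regresses the target under a mean squared error loss. The dot-product scorer, the loss and every helper name are assumptions for illustration. The original Keras network may combine the two node embeddings differently.

```python
from collections import deque

import numpy as np


def adjacency_from_edges(num_nodes, edges):
    """Undirected adjacency lists. Assumption: citation direction is ignored."""
    neighbors = [set() for _ in range(num_nodes)]
    for u, v in edges:
        neighbors[u].add(v)
        neighbors[v].add(u)
    return [sorted(s) for s in neighbors]


def bfs_distances(adjacency, source):
    """Unweighted shortest-path length from source to every reachable node."""
    dist = {source: 0}
    queue = deque([source])
    while queue:
        u = queue.popleft()
        for v in adjacency[u]:
            if v not in dist:
                dist[v] = dist[u] + 1
                queue.append(v)
    return dist


def sample_pairs(adjacency, num_pairs, seed=0):
    """Sample (u, v, target) with target = 1 / shortest_path(u, v), or 0 if unreachable.

    Pairs with u == v are skipped because 1 / 0 is undefined.
    """
    n = len(adjacency)
    if n < 2:
        raise ValueError("need at least two nodes")
    rng = np.random.default_rng(seed)
    cache = {}  # BFS result per source node, reused across pairs
    pairs = []
    while len(pairs) < num_pairs:
        u, v = (int(i) for i in rng.integers(0, n, size=2))
        if u == v:
            continue
        if u not in cache:
            cache[u] = bfs_distances(adjacency, u)
        d = cache[u].get(v)
        pairs.append((u, v, 0.0 if d is None else 1.0 / d))
    return pairs


def train_embeddings(pairs, num_nodes, dim=16, lr=0.5, epochs=300, seed=0):
    """Full-batch gradient descent on mean((sigmoid(e_u . e_v + c) - target) ** 2).

    Assumption: a dot-product scorer with a shared bias c. Returns the
    embedding matrix, the bias and the loss recorded at every epoch.
    """
    rng = np.random.default_rng(seed)
    E = rng.normal(scale=0.1, size=(num_nodes, dim))
    c = 0.0
    u = np.array([p[0] for p in pairs])
    v = np.array([p[1] for p in pairs])
    t = np.array([p[2] for p in pairs], dtype=float)
    losses = []
    for _ in range(epochs):
        eu, ev = E[u], E[v]
        pred = 1.0 / (1.0 + np.exp(-(np.sum(eu * ev, axis=1) + c)))
        losses.append(float(np.mean((pred - t) ** 2)))
        # d loss / d score for each pair (chain rule through the sigmoid)
        g = 2.0 * (pred - t) * pred * (1.0 - pred) / len(t)
        grad_E = np.zeros_like(E)
        np.add.at(grad_E, u, g[:, None] * ev)  # add.at accumulates repeated node indices
        np.add.at(grad_E, v, g[:, None] * eu)
        E -= lr * grad_E
        c -= lr * float(g.sum())
    return E, c, losses
```

The per-source cache keeps sampling affordable on a graph the size of Cora, because each node's BFS runs at most once.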

The next step is to use the pre-trained node embedding as input to the classification model. We also add another input: the average binary features of the neighboring nodes, where neighbors are selected by the distance between their learned embedding vectors.
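The neighbor-averaged input can be computed directly from the learned embeddings. This sketch assumes Euclidean distance and a fixed number k of nearest neighbors in embedding space. Both are hypothetical choices, since the post does not specify them. The sketch then concatenates the node's own binary features, its embedding and the neighbor average into one classifier input matrix.

```python
import numpy as np


def embedding_neighbor_average(E, X, k=5):
    """Average the binary features X of each node's k nearest neighbors in embedding space.

    Assumptions: Euclidean distance and a fixed k; the node itself is excluded.
    """
    sq_norms = np.sum(E ** 2, axis=1)
    dist2 = sq_norms[:, None] + sq_norms[None, :] - 2.0 * (E @ E.T)  # pairwise squared distances
    np.fill_diagonal(dist2, np.inf)  # never count a node as its own neighbor
    k = min(k, len(E) - 1)
    nearest = np.argsort(dist2, axis=1)[:, :k]  # shape (num_nodes, k)
    return X[nearest].mean(axis=1)  # shape (num_nodes, num_features)


def build_classifier_inputs(E, X, k=5):
    """Concatenate own binary features, the embedding and the neighbor average."""
    return np.hstack([X, E, embedding_neighbor_average(E, X, k)])
```

The dense pairwise distance matrix is simple and fits in memory for a few thousand nodes. Much larger graphs would call for an approximate nearest-neighbor index instead.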

The resulting classification network is described in the following figure :

Using pretrained embeddings to do node classification

Graph embedding classification model Accuracy : 73.06%

We can see that adding learned graph features as input to the classification model significantly improves the classification accuracy compared to the baseline model, from **53.28% to 73.06%** 😄.

Improving Graph feature learning :

We can improve the previous model further by pushing the pre-training further. We feed the binary features into the node embedding network, then reuse the pre-trained weights of the binary-feature branch in addition to the node embedding vector. The resulting model relies on more useful representations of the binary features, learned from the graph structure.
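The following is a minimal sketch of this idea. It assumes a node representation formed by a linear projection of the binary features plus a free per-node embedding. That representation is trained on the inverse shortest-path objective described earlier in the post, using a sigmoid of a dot product and mean squared error. After pre-training, the projection weights W are reused to build the classifier inputs. The architecture, the function names and the hyperparameters are all assumptions, and the post's Keras model may differ.

```python
import numpy as np


def node_repr(X, W, E, idx):
    """Assumed node representation: projected binary features plus a free embedding."""
    return X[idx] @ W + E[idx]


def pretrain_feature_encoder(X, pairs, dim=16, lr=0.5, epochs=300, seed=0):
    """Gradient descent on mean((sigmoid(h_u . h_v + c) - target) ** 2) over (u, v, target) pairs."""
    rng = np.random.default_rng(seed)
    n, d = X.shape
    W = rng.normal(scale=0.1, size=(d, dim))  # binary-feature projection, reused later
    E = rng.normal(scale=0.1, size=(n, dim))  # per-node embedding
    c = 0.0
    u = np.array([p[0] for p in pairs])
    v = np.array([p[1] for p in pairs])
    t = np.array([p[2] for p in pairs], dtype=float)
    losses = []
    for _ in range(epochs):
        hu, hv = node_repr(X, W, E, u), node_repr(X, W, E, v)
        pred = 1.0 / (1.0 + np.exp(-(np.sum(hu * hv, axis=1) + c)))
        losses.append(float(np.mean((pred - t) ** 2)))
        g = (2.0 * (pred - t) * pred * (1.0 - pred) / len(t))[:, None]  # d loss / d score
        dhu, dhv = g * hv, g * hu
        grad_W = X[u].T @ dhu + X[v].T @ dhv  # both endpoints share the projection W
        grad_E = np.zeros_like(E)
        np.add.at(grad_E, u, dhu)
        np.add.at(grad_E, v, dhv)
        W -= lr * grad_W
        E -= lr * grad_E
        c -= lr * float(g.sum())
    return W, E, c, losses


def classifier_inputs(X, W, E):
    """Reuse the pre-trained feature projection alongside the node embeddings."""
    return np.hstack([X @ W, E])
```

On Cora, X would have 1433 columns, so W maps the sparse bag of words to a dense dim-sized vector shaped by the graph objective.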

Improved Graph embedding classification model Accuracy : 76.35%

This adds a further 3.29 percentage points of accuracy over the previous approach (from 73.06% to 76.35%).

Conclusion :

In this post we saw that we can learn useful representations from graph-structured data and then use these representations to improve the generalization performance of a node classification model from **53.28% to 76.35%** 😎.

Code to reproduce the results is available here : https://github.com/CVxTz/graph_classification

Owner
Mansar Youness