This repository contains the code used for Predicting Patient Outcomes with Graph Representation Learning (https://arxiv.org/abs/2101.03940).

Overview

Predicting Patient Outcomes with Graph Representation Learning

A video of the spotlight talk at W3PHIAI (an AAAI workshop) is available here:

Watch the video

Citation

If you use this code or the models in your research, please cite the following:

@misc{rocheteautong2021,
      title={Predicting Patient Outcomes with Graph Representation Learning}, 
      author={Emma Rocheteau and Catherine Tong and Petar Veličković and Nicholas Lane and Pietro Liò},
      year={2021},
      eprint={2101.03940},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Motivation

Recent work on predicting patient outcomes in the Intensive Care Unit (ICU) has focused heavily on the physiological time series data, largely ignoring sparse data such as diagnoses and medications. When they are included, they are usually concatenated in the late stages of a model, which may struggle to learn from rarer disease patterns. Instead, we propose a strategy to exploit diagnoses as relational information by connecting similar patients in a graph. To this end, we propose LSTM-GNN for patient outcome prediction tasks: a hybrid model combining Long Short-Term Memory networks (LSTMs) for extracting temporal features and Graph Neural Networks (GNNs) for extracting the patient neighbourhood information. We demonstrate that LSTM-GNNs outperform the LSTM-only baseline on length of stay prediction tasks on the eICU database. More generally, our results indicate that exploiting information from neighbouring patient cases using graph neural networks is a promising research direction, yielding tangible returns in supervised learning performance on Electronic Health Records.

Pre-Processing Instructions

eICU Pre-Processing

  1. To run the SQL files you must have the eICU database set up: https://physionet.org/content/eicu-crd/2.0/.

  2. Follow the instructions at https://eicu-crd.mit.edu/tutorials/install_eicu_locally/ to ensure the correct connection configuration.

  3. Set eICU_path in paths.json to a convenient location on your computer, and do the same in eICU_preprocessing/create_all_tables.sql by using find and replace on '/Users/emmarocheteau/PycharmProjects/eICU-GNN-LSTM/eICU_data/'. Keep the trailing '/' at the end.

  4. In your terminal, navigate to the project directory, then type the following commands:

    psql 'dbname=eicu user=eicu options=--search_path=eicu'
    

    Inside the psql console:

    \i eICU_preprocessing/create_all_tables.sql
    

    This step might take a couple of hours.

    To quit the psql console:

    \q
    
  5. Then run the pre-processing scripts in your terminal. This will need to run overnight:

    python3 -m eICU_preprocessing.run_all_preprocessing
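
The path changes in step 3 can also be scripted. The following is a minimal sketch, not part of the repository: the helper name point_to_data_dir is hypothetical, and it assumes that paths.json is a flat JSON object and that eICU_path should carry the same trailing '/' as the path in the SQL file.

```python
import json
from pathlib import Path

# Hard-coded prefix quoted in step 3; the trailing '/' is kept on purpose.
OLD_PREFIX = "/Users/emmarocheteau/PycharmProjects/eICU-GNN-LSTM/eICU_data/"


def point_to_data_dir(paths_json: Path, sql_file: Path, new_dir: str) -> None:
    """Hypothetical helper: set eICU_path and rewrite the prefix in the SQL file."""
    # Normalise the new location so that it ends with exactly one '/'.
    new_prefix = new_dir.rstrip("/") + "/"

    # Update eICU_path and leave every other key in paths.json untouched.
    paths = json.loads(paths_json.read_text())
    paths["eICU_path"] = new_prefix
    paths_json.write_text(json.dumps(paths, indent=2) + "\n")

    # Replace every occurrence of the old prefix in the SQL script.
    sql = sql_file.read_text()
    sql_file.write_text(sql.replace(OLD_PREFIX, new_prefix))
```

For example, point_to_data_dir(Path("paths.json"), Path("eICU_preprocessing/create_all_tables.sql"), "/data/eICU_data") would be run once from the project directory before starting psql.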
    

Graph Construction

To build the graphs, use the following scripts.

This command builds most of the graphs that we use. You can change the arguments passed to the script.

python3 -m graph_construction.create_graph --freq_adjust --penalise_non_shared --k 3 --mode k_closest
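
To illustrate what a k-closest graph is, here is a simplified sketch. It is not the repository's implementation: it scores patient pairs by the raw number of shared diagnoses, does not model the --freq_adjust or --penalise_non_shared options, and the function name k_closest_edges is hypothetical.

```python
from typing import List, Set, Tuple


def k_closest_edges(diagnoses: List[Set[str]], k: int) -> List[Tuple[int, int]]:
    """Connect each patient to the k other patients sharing the most diagnoses."""
    edges: List[Tuple[int, int]] = []
    for i, diag_i in enumerate(diagnoses):
        # Score every other patient by the number of diagnoses shared with patient i.
        scores = [
            (len(diag_i & diag_j), j)
            for j, diag_j in enumerate(diagnoses)
            if j != i
        ]
        # Highest score first; ties are broken by the lower patient index.
        scores.sort(key=lambda s: (-s[0], s[1]))
        edges.extend((i, j) for _, j in scores[:k])
    return edges
```

With k=3, as in the command above, every patient node receives three outgoing edges. A pairwise loop like this scales quadratically with the number of patients, so it is only suitable for small illustrations.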

Write the diagnosis strings into the eICU_data folder:

python3 -m graph_construction.get_diagnosis_strings

Compute the BERT embeddings:

python3 -m graph_construction.bert

Create the graph from the BERT embeddings:

python3 -m graph_construction.create_bert_graph --k 3 --mode k_closest
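
For embedding-based graphs, the same k-closest idea can be applied to a similarity between embedding vectors. The sketch below assumes cosine similarity, which may differ from the measure used in graph_construction.create_bert_graph; the function names are hypothetical and the embeddings are plain Python lists for illustration.

```python
import math
from typing import List, Sequence, Tuple


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity between two vectors (0.0 if either vector is all zeros)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def k_closest_by_embedding(
    embeddings: List[Sequence[float]], k: int
) -> List[Tuple[int, int]]:
    """Connect each patient to the k patients with the most similar embeddings."""
    edges: List[Tuple[int, int]] = []
    for i, e_i in enumerate(embeddings):
        # Rank all other patients by decreasing similarity to patient i.
        ranked = sorted(
            (j for j in range(len(embeddings)) if j != i),
            key=lambda j: (-cosine(e_i, embeddings[j]), j),
        )
        edges.extend((i, j) for j in ranked[:k])
    return edges
```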

Alternatively, you can request to download our graphs using this link: https://drive.google.com/drive/folders/1yWNLhGOTPhu6mxJRjKCgKRJCJjuToBS4?usp=sharing

Training the ML Models

Before training the ML models, complete the following steps.

  1. Set data_dir, graph_dir, log_path and ray_dir in paths.json to convenient locations.

  2. Run the following to unpack the processed eICU data into mmap files for easy loading during training. The mmap files will be saved in data_dir.

    python3 -m src.dataloader.convert
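
Memory-mapped files make loading easy because a slice can be read without loading the whole file into memory. The sketch below shows the idea with the standard library only. It assumes a flat file of 32-bit floats, which is an illustrative layout and not necessarily the one written by src.dataloader.convert; both function names are hypothetical.

```python
import mmap
from array import array
from typing import List, Sequence


def write_floats(path: str, values: Sequence[float]) -> None:
    """Write values to disk as a flat array of 32-bit floats."""
    with open(path, "wb") as f:
        array("f", values).tofile(f)


def read_float_slice(path: str, start: int, stop: int) -> List[float]:
    """Read elements [start, stop) through a memory map of the file."""
    with open(path, "rb") as f, mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
        base = memoryview(mm)
        view = base.cast("f")  # reinterpret the raw bytes as 32-bit floats
        try:
            # Only the pages backing this slice need to be read from disk.
            return view[start:stop].tolist()
        finally:
            # Release the views so the map can be closed cleanly.
            view.release()
            base.release()
```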
    

The following commands train and evaluate the models introduced in our paper.

N.B.

  • The models are structured using PyTorch Lightning. Graph neural networks and neighbourhood sampling are implemented using PyTorch Geometric.

  • Our models assume a default graph built with k=3 under a k-closest scheme. If you wish to use other graphs, refer to read_graph_edge_list in src/dataloader/pyg_reader.py to add a reference handle to version2filename for your graph.

  • The default task is in-hospital mortality prediction (ihm). Add --task los to the command to run the length-of-stay prediction (los) task instead.

  • These commands use the best set of hyperparameters. To use other hyperparameters, remove --read_best from the command and refer to src/args.py.

a. LSTM-GNN

The following runs the training and evaluation for LSTM-GNN models. --gnn_name can be set as gat, sage, or mpnn. When mpnn is used, add --ns_sizes 10 to the command.

python3 -m train_ns_lstmgnn --bilstm --ts_mask --add_flat --class_weights --gnn_name gat --add_diag --read_best
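
The rules above (the choice of --gnn_name, the extra --ns_sizes 10 for mpnn, and the optional --task los) can be captured in a small command builder. This is a convenience sketch and not part of the repository: the function name lstmgnn_command is hypothetical, and it uses only the flags documented in this README.

```python
from typing import List

ALLOWED_GNNS = {"gat", "sage", "mpnn"}


def lstmgnn_command(gnn_name: str = "gat", task: str = "ihm", read_best: bool = True) -> List[str]:
    """Build the argument list for an LSTM-GNN training run."""
    if gnn_name not in ALLOWED_GNNS:
        raise ValueError(f"gnn_name must be one of {sorted(ALLOWED_GNNS)}")
    cmd = [
        "python3", "-m", "train_ns_lstmgnn",
        "--bilstm", "--ts_mask", "--add_flat", "--class_weights",
        "--gnn_name", gnn_name, "--add_diag",
    ]
    if gnn_name == "mpnn":
        # The README asks for --ns_sizes 10 whenever mpnn is used.
        cmd += ["--ns_sizes", "10"]
    if task == "los":
        # ihm is the default task, so only los needs an explicit flag.
        cmd += ["--task", "los"]
    if read_best:
        cmd.append("--read_best")
    return cmd
```

The returned list can be passed to subprocess.run from the project directory, for example subprocess.run(lstmgnn_command("mpnn", task="los"), check=True).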

The following runs a hyperparameter search.

python3 -m src.hyperparameters.lstmgnn_search --bilstm --ts_mask --add_flat --class_weights  --gnn_name gat --add_diag

b. Dynamic LSTM-GNN

The following runs the training & evaluation for dynamic LSTM-GNN models. --gnn_name can be set as gcn, gat, or mpnn.

python3 -m train_dynamic --bilstm --random_g --ts_mask --add_flat --class_weights --gnn_name mpnn --read_best

The following runs a hyperparameter search.

python3 -m src.hyperparameters.dynamic_lstmgnn_search --bilstm --random_g --ts_mask --add_flat --class_weights --gnn_name mpnn

c. GNN

The following runs the GNN models (with neighbourhood sampling). --gnn_name can be set as gat, sage, or mpnn. When mpnn is used, add --ns_sizes 10 to the command.

python3 -m train_ns_gnn --ts_mask --add_flat --class_weights --gnn_name gat --add_diag --read_best

The following runs a hyperparameter search.

python3 -m src.hyperparameters.ns_gnn_search --ts_mask --add_flat --class_weights --gnn_name gat --add_diag

d. LSTM (Baselines)

The following runs the baseline bi-LSTMs. To remove diagnoses from the input vector, remove --add_diag from the command.

python3 -m train_ns_lstm --bilstm --ts_mask --add_flat --class_weights --num_workers 0 --add_diag --read_best

The following runs a hyperparameter search.

python3 -m src.hyperparameters.lstm_search --bilstm --ts_mask --add_flat --class_weights --num_workers 0 --add_diag