Code for "Optimizing risk-based breast cancer screening policies with reinforcement learning"

Tempo: Optimizing risk-based breast cancer screening policies with reinforcement learning

Introduction

This repository was used to develop Tempo, as described in: Optimizing risk-based breast cancer screening policies with reinforcement learning.

Screening programs must balance the benefits of early detection against the costs of over-screening. Here, we introduce a novel reinforcement learning-based framework for personalized screening, Tempo, and demonstrate its efficacy in the context of breast cancer. We trained our risk-based screening policies on a large screening mammography dataset from Massachusetts General Hospital (MGH), USA, and validated them on held-out patients from MGH and on external datasets from Emory, USA; Karolinska, Sweden; and Chang Gung Memorial Hospital (CGMH), Taiwan. Across all test sets, we found that a Tempo policy combined with an image-based AI risk model, Mirai [1], was significantly more efficient than current regimes used in clinical practice in terms of simulated early detection per screen frequency. Moreover, we showed that the same Tempo policy can be easily adapted to a wide range of possible screening preferences, allowing clinicians to select their desired trade-off between early detection and screening cost without training new policies. Finally, we demonstrated that Tempo policies based on AI-based risk models outperformed Tempo policies based on less accurate clinical risk models. Altogether, our results show that pairing AI-based risk models with agile AI-designed screening policies has the potential to improve screening programs, advancing early detection while reducing over-screening.

This code base is meant to provide exact implementation details for the development of Tempo.

Aside on Software Dependencies

This code assumes Python 3.6 and a Linux environment. The package requirements can be installed with pip:

pip install -r requirements.txt

Tempo-Mirai assumes access to Mirai risk assessments. Resources for using Mirai are shown here.

Method

[Figure: overview of the Tempo method]

Our full framework, named Tempo, is depicted above. We first train a risk progression neural network to predict future risk assessments given previous assessments. This model is then used to estimate patient risk at unobserved timepoints, which enables us to simulate risk-based screening policies. Next, we train our screening policy, which is implemented as a neural network, to maximize the reward (i.e., a combination of early detection and screening cost) on our retrospective training set. We train our screening policy to support all possible trade-offs between early detection and screening cost using envelope Q-learning [2], an RL algorithm designed to balance multiple objectives. The inputs to our screening policy are the patient's risk assessment and the desired weighting between rewards (i.e., the screening preference). The output of the policy is a recommendation for when to return for the next screen, ranging from six months to three years in the future, in multiples of six months. Our reward balances two contrasting aspects: one reflecting the imaging cost, i.e., the average number of mammograms per year recommended by the policy, and one modeling the early detection benefit relative to the retrospective screening trajectory. Our early detection reward measures the time difference in months between each patient's recommended screening date, if it was after their last negative mammogram, and their actual diagnosis date. We evaluate screening policies by simulating their recommendations for held-out patients.
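To make the two reward components concrete, here is a minimal sketch in plain Python. It is not taken from the repository: the function names, the month-based time axis, the choice to return zero benefit when the recommended screen does not fall after the last negative mammogram, and the linear weighting in `scalarized_reward` are all assumptions made for illustration.

```python
from typing import Sequence, Tuple


def screening_cost(intervals_months: Sequence[int]) -> float:
    """Average number of mammograms per year implied by a sequence of
    recommended follow-up intervals (each given in months)."""
    if not intervals_months or any(m <= 0 for m in intervals_months):
        raise ValueError("need at least one positive interval")
    return 12.0 * len(intervals_months) / sum(intervals_months)


def early_detection_months(recommended_month: int,
                           last_negative_month: int,
                           diagnosis_month: int) -> float:
    """Months between the recommended screen and the actual diagnosis.

    Assumption: a recommended screen that does not fall after the last
    negative mammogram earns no benefit. The value is signed, so a screen
    recommended after the diagnosis date yields a negative number.
    """
    if recommended_month <= last_negative_month:
        return 0.0
    return float(diagnosis_month - recommended_month)


def scalarized_reward(early_detection: float,
                      cost: float,
                      preference: Tuple[float, float]) -> float:
    """Combine both objectives with a (hypothetical) preference weighting;
    the screening cost enters as a penalty."""
    w_early, w_cost = preference
    return w_early * early_detection - w_cost * cost
```

For example, a policy that always recommends a six-month follow-up has `screening_cost([6, 6, 6, 6]) == 2.0` mammograms per year, while annual screening gives `1.0`.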

Training Risk Progression Models

We experimented with different learning rates, hidden sizes, numbers of layers and dropout, and chose the model that obtained the lowest validation KL divergence on the MGH validation set. Our final risk progression RNN had two layers, a hidden dimension size of 100, a dropout of 0.25, and was trained for 30 epochs with a learning rate of 1e-3 using the Adam optimizer.

To reproduce our grid search for our Mirai risk progression model, you can run:

python scripts/dispatcher.py --experiment_config_path configs/risk_progression/gru.json

Given a trained risk progression model, we can now estimate unobserved risk assessments auto-regressively. At each time step, the model takes as input the previous risk assessment and the prior hidden state, and predicts the risk assessment at the next time step. If the real previous assessment is not available, the previously predicted assessment is used in its place.
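The auto-regressive procedure above can be sketched as follows. This is an illustrative outline, not the repository's implementation: `step` stands in for one forward pass of the trained risk progression model, and `fill_missing_assessments` and `toy_step` are hypothetical names.

```python
from typing import Callable, List, Optional, Sequence, Tuple

Risk = List[float]      # one risk assessment (e.g. a vector of risk scores)
Hidden = List[float]    # recurrent hidden state of the progression model
StepFn = Callable[[Risk, Hidden], Tuple[Risk, Hidden]]


def fill_missing_assessments(step: StepFn,
                             observed: Sequence[Optional[Risk]],
                             initial_hidden: Hidden) -> List[Risk]:
    """Return a full trajectory where every None entry in `observed` is
    replaced by the model's prediction.

    At each step the model is fed the previous assessment: the real one
    when it exists, otherwise the one predicted at the previous step.
    """
    if not observed or observed[0] is None:
        raise ValueError("the first assessment must be observed")
    filled: List[Risk] = [observed[0]]
    hidden = initial_hidden
    for t in range(1, len(observed)):
        prediction, hidden = step(filled[-1], hidden)
        real = observed[t]
        filled.append(real if real is not None else prediction)
    return filled


def toy_step(risk: Risk, hidden: Hidden) -> Tuple[Risk, Hidden]:
    """Stand-in for the trained model: doubles the risk, keeps the state."""
    return [2.0 * r for r in risk], hidden


trajectory = fill_missing_assessments(
    toy_step, [[0.125], None, None, [0.25], None], initial_hidden=[0.0]
)
```

With the toy model, the two gaps after the first observation are filled from successive predictions, and the final gap is predicted from the real assessment at the fourth time step.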

Training Tempo Personalized Screening Policies

We implemented our personalized screening policy as a multilayer perceptron, which took as input a risk assessment and a weighting between rewards, and predicted the Q-value for each action (i.e., each follow-up recommendation) across the rewards. This network was trained using Envelope Q-Learning [2]. We experimented with different numbers of layers, hidden dimension sizes, learning rates, dropouts, exploration epsilons, target network reset rates and weight decay rates.
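As an illustration of how such a network's output could be turned into a recommendation, the sketch below picks the action whose preference-weighted Q-values are highest. The function name, the layout of `q_values` (one list of per-reward estimates per action), and the example numbers are assumptions; the network itself is not shown.

```python
from typing import List, Sequence

# Follow-up intervals available to the policy, in months
# (six months to three years, in multiples of six months).
ACTIONS_MONTHS = [6, 12, 18, 24, 30, 36]


def recommend_follow_up(q_values: Sequence[Sequence[float]],
                        preference: Sequence[float]) -> int:
    """Pick the follow-up interval with the highest weighted Q-value.

    q_values[a][k] is the estimated Q-value of action a for reward k;
    preference[k] is the weight placed on reward k.
    """
    if len(q_values) != len(ACTIONS_MONTHS):
        raise ValueError("expected one row of Q-values per action")
    scores: List[float] = [
        sum(w * q for w, q in zip(preference, q_a)) for q_a in q_values
    ]
    best_action = max(range(len(scores)), key=scores.__getitem__)
    return ACTIONS_MONTHS[best_action]


# Made-up Q-values: column 0 is early detection, column 1 is the
# (negative) screening cost.
example_q = [
    [5.0, -2.0], [4.0, -1.0], [3.0, -0.5],
    [2.0, -0.4], [1.0, -0.3], [0.0, -0.2],
]
```

Changing only the preference vector changes the recommendation without retraining: weighting early detection alone selects the shortest interval in this example, and weighting cost alone selects the longest.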

To reproduce our grid search for our Tempo screening policy, you can run:

python scripts/dispatcher.py --experiment_config_path configs/screening/neural.json

Data availability

All datasets were used under license to the respective hospital system for the current study and are not publicly available. To access the MGH dataset, investigators should reach out to C.L. to apply for an IRB-approved research collaboration and obtain an appropriate Data Use Agreement. To access the Karolinska dataset, investigators should reach out to F.S. to apply for an approved research collaboration and sign a Data Use Agreement. To access the CGMH dataset, investigators should contact G.L. to apply for an IRB-approved research collaboration. To access the Emory dataset, investigators should reach out to H.T. to apply for an approved collaboration.

References

[1] Yala, Adam, et al. "Toward robust mammography-based models for breast cancer risk." Science Translational Medicine 13.578 (2021).

[2] Yang, Runzhe, Xingyuan Sun, and Karthik Narasimhan. "A generalized algorithm for multi-objective reinforcement learning and policy adaptation." arXiv preprint arXiv:1908.08342 (2019).

Citing Tempo

@article{yala2021optimizing,
  title={Optimizing risk-based breast cancer screening policies with reinforcement learning},
  author={Yala, Adam and Mikhael, Peter and Lehman, Constance and Lin, Gigin and Strand, Fredrik and Wang, Yung-Liang and Hughes, Kevin and Satuluru, Siddharth and Kim, Thomas and Banerjee, Imon and others},
  year={2021}
}