Pytorch Code for "Medical Transformer: Gated Axial-Attention for Medical Image Segmentation"

Overview

Medical-Transformer

Pytorch Code for the paper "Medical Transformer: Gated Axial-Attention for Medical Image Segmentation"

About this repo:

This repo hosts the code for the following networks:

  1. Gated Axial Attention U-Net
  2. MedT

Introduction

Majority of existing Transformer-based network architectures proposed for vision applications require large-scale datasets to train properly. However, compared to the datasets for vision applications, for medical imaging the number of data samples is relatively low, making it difficult to efficiently train transformers for medical appli- cations. To this end, we propose a Gated Axial-Attention model which extends the existing architectures by introducing an additional control mechanism in the self-attention module. Furthermore, to train the model effectively on medical images, we propose a Local-Global training strat- egy (LoGo) which further improves the performance. Specifically, we op- erate on the whole image and patches to learn global and local features, respectively. The proposed Medical Transformer (MedT) uses LoGo training strategy on Gated Axial Attention U-Net.

Using the code:

  • Clone this repository:
git clone https://github.com/jeya-maria-jose/Medical-Transformer
cd Medical-Transformer

The code is stable using Python 3.6.10, Pytorch 1.4.0

To install all the dependencies using conda:

conda env create -f environment.yml
conda activate medt

To install all the dependencies using pip:

pip install -r requirements.txt

Links for downloading the public Datasets:

  1. GLAS Dataset - Link (Original) | Link (Resized)
  2. MoNuSeG Dataset - Link (Original)
  3. Brain Anatomy US dataset from the paper will be made public soon !

Using the Code for your dataset

Dataset Preparation

Prepare the dataset in the following format for easy use of the code. The train and test folders should contain two subfolders each: img and label. Make sure the images their corresponding segmentation masks are placed under these folders and have the same name for easy correspondance. Please change the data loaders to your need if you prefer not preparing the dataset in this format.

Train Folder-----
      img----
          0001.png
          0002.png
          .......
      label---
          0001.png
          0002.png
          .......
Validation Folder-----
      img----
          0001.png
          0002.png
          .......
      label---
          0001.png
          0002.png
          .......
Test Folder-----
      img----
          0001.png
          0002.png
          .......
      label---
          0001.png
          0002.png
          .......
  • The ground truth images should have pixels corresponding to the labels. Example: In case of binary segmentation, the pixels in the GT should be 0 or 255.

Training Command:

python train.py --train_dataset "enter train directory" --val_dataset "enter validation directory" --direc 'path for results to be saved' --batch_size 4 --epoch 400 --save_freq 10 --modelname "gatedaxialunet" --learning_rate 0.001 --imgsize 128 --gray "no"
Change modelname to MedT or logo to train them

Testing Command:

python test.py --loaddirec "./saved_model_path/model_name.pth" --val_dataset "test dataset directory" --direc 'path for results to be saved' --batch_size 1 --modelname "gatedaxialunet" --imgsize 128 --gray "no"

The results including predicted segmentations maps will be placed in the results folder along with the model weights. Run the performance metrics code in MATLAB for calculating F1 Score and mIoU.

Notes:

Note that these experiments were conducted in Nvidia Quadro 8000 with 48 GB memory.

Acknowledgement:

The dataloader code is inspired from pytorch-UNet . The axial attention code is developed from axial-deeplab.

Citation:

To add

Open an issue or mail me directly in case of any queries or suggestions.

Owner
Jeya Maria Jose
PhD Student at Johns Hopkins University.
Jeya Maria Jose
This project deploys a yolo fastest model in the form of tflite on raspberry 3b+. The model is from another repository of mine called -Trash-Classification-Car

Deploy-yolo-fastest-tflite-on-raspberry 觉得有用的话可以顺手点个star嗷 这个项目将垃圾分类小车中的tflite模型移植到了树莓派3b+上面。 该项目主要是为了记录在树莓派部署yolo fastest tflite的流程 (之后有时间会尝试用C++部署来提升

7 Aug 16, 2022
Replication Code for "Self-Supervised Bug Detection and Repair" NeurIPS 2021

Self-Supervised Bug Detection and Repair This is the reference code to replicate the research in Self-Supervised Bug Detection and Repair in NeurIPS 2

Microsoft 85 Dec 24, 2022
Code for CVPR2019 Towards Natural and Accurate Future Motion Prediction of Humans and Animals

Motion prediction with Hierarchical Motion Recurrent Network Introduction This work concerns motion prediction of articulate objects such as human, fi

Shuang Wu 85 Dec 11, 2022
Keras Model Implementation Walkthrough

Keras Model Implementation Walkthrough

Luke Wood 17 Sep 27, 2022
A lightweight library designed to accelerate the process of training PyTorch models by providing a minimal

A lightweight library designed to accelerate the process of training PyTorch models by providing a minimal, but extensible training loop which is flexible enough to handle the majority of use cases,

Chris Hughes 110 Dec 23, 2022
NCVX (NonConVeX): A User-Friendly and Scalable Package for Nonconvex Optimization in Machine Learning.

The source code is temporariy removed, as we are solving potential copyright and license issues with GRANSO (http://www.timmitchell.com/software/GRANS

SUN Group @ UMN 28 Aug 03, 2022
A modular, research-friendly framework for high-performance and inference of sequence models at many scales

T5X T5X is a modular, composable, research-friendly framework for high-performance, configurable, self-service training, evaluation, and inference of

Google Research 1.1k Jan 08, 2023
PyTorch implementation of a Real-ESRGAN model trained on custom dataset

Real-ESRGAN PyTorch implementation of a Real-ESRGAN model trained on custom dataset. This model shows better results on faces compared to the original

Sber AI 160 Jan 04, 2023
A simple tutoral for error correction task, based on Pytorch

gramcorrector A simple tutoral for error correction task, based on Pytorch Grammatical Error Detection (sentence-level) a binary sequence-based classi

peiyuan_gong 8 Dec 03, 2022
Official PyTorch implementation of the Fishr regularization for out-of-distribution generalization

Fishr: Invariant Gradient Variances for Out-of-distribution Generalization Official PyTorch implementation of the Fishr regularization for out-of-dist

62 Dec 22, 2022
TensorFlow Implementation of Unsupervised Cross-Domain Image Generation

Domain Transfer Network (DTN) TensorFlow implementation of Unsupervised Cross-Domain Image Generation. Requirements Python 2.7 TensorFlow 0.12 Pickle

Yunjey Choi 865 Nov 17, 2022
Speech Separation Using an Asynchronous Fully Recurrent Convolutional Neural Network

Speech Separation Using an Asynchronous Fully Recurrent Convolutional Neural Network This repository is the official implementation of Speech Separati

Kai Li (李凯) 116 Nov 09, 2022
Automatic meme generation model using Tensorflow Keras.

Memefly You can find the project at MemeflyAI. Contributors Nick Buukhalter Harsh Desai Han Lee Project Overview Trello Board Product Canvas Automatic

BloomTech Labs 2 Jan 13, 2022
Code for "Infinitely Deep Bayesian Neural Networks with Stochastic Differential Equations"

Infinitely Deep Bayesian Neural Networks with SDEs This library contains JAX and Pytorch implementations of neural ODEs and Bayesian layers for stocha

Winnie Xu 95 Nov 26, 2021
3D position tracking for soccer players with multi-camera videos

This repo contains a full pipeline to support 3D position tracking of soccer players, with multi-view calibrated moving/fixed video sequences as inputs.

Yuchang Jiang 72 Dec 27, 2022
Manifold Alignment for Semantically Aligned Style Transfer

Manifold Alignment for Semantically Aligned Style Transfer [Paper] Getting Started MAST has been tested on CentOS 7.6 with python = 3.6. It supports

35 Nov 14, 2022
🔥🔥High-Performance Face Recognition Library on PaddlePaddle & PyTorch🔥🔥

face.evoLVe: High-Performance Face Recognition Library based on PaddlePaddle & PyTorch Evolve to be more comprehensive, effective and efficient for fa

Zhao Jian 3.1k Jan 04, 2023
NVTabular is a feature engineering and preprocessing library for tabular data designed to quickly and easily manipulate terabyte scale datasets used to train deep learning based recommender systems.

NVTabular is a feature engineering and preprocessing library for tabular data designed to quickly and easily manipulate terabyte scale datasets used to train deep learning based recommender systems.

880 Jan 07, 2023
Implementation of parameterized soft-exponential activation function.

Soft-Exponential-Activation-Function: Implementation of parameterized soft-exponential activation function. In this implementation, the parameters are

Shuvrajeet Das 1 Feb 23, 2022
Robust Consistent Video Depth Estimation

[CVPR 2021] Robust Consistent Video Depth Estimation This repository contains Python and C++ implementation of Robust Consistent Video Depth, as descr

Facebook Research 213 Dec 17, 2022