Pytorch Code for "Medical Transformer: Gated Axial-Attention for Medical Image Segmentation"

Overview

Medical-Transformer

Pytorch Code for the paper "Medical Transformer: Gated Axial-Attention for Medical Image Segmentation"

About this repo:

This repo hosts the code for the following networks:

  1. Gated Axial Attention U-Net
  2. MedT

Introduction

Majority of existing Transformer-based network architectures proposed for vision applications require large-scale datasets to train properly. However, compared to the datasets for vision applications, for medical imaging the number of data samples is relatively low, making it difficult to efficiently train transformers for medical appli- cations. To this end, we propose a Gated Axial-Attention model which extends the existing architectures by introducing an additional control mechanism in the self-attention module. Furthermore, to train the model effectively on medical images, we propose a Local-Global training strat- egy (LoGo) which further improves the performance. Specifically, we op- erate on the whole image and patches to learn global and local features, respectively. The proposed Medical Transformer (MedT) uses LoGo training strategy on Gated Axial Attention U-Net.

Using the code:

  • Clone this repository:
git clone https://github.com/jeya-maria-jose/Medical-Transformer
cd Medical-Transformer

The code is stable using Python 3.6.10, Pytorch 1.4.0

To install all the dependencies using conda:

conda env create -f environment.yml
conda activate medt

To install all the dependencies using pip:

pip install -r requirements.txt

Links for downloading the public Datasets:

  1. GLAS Dataset - Link (Original) | Link (Resized)
  2. MoNuSeG Dataset - Link (Original)
  3. Brain Anatomy US dataset from the paper will be made public soon !

Using the Code for your dataset

Dataset Preparation

Prepare the dataset in the following format for easy use of the code. The train and test folders should contain two subfolders each: img and label. Make sure the images their corresponding segmentation masks are placed under these folders and have the same name for easy correspondance. Please change the data loaders to your need if you prefer not preparing the dataset in this format.

Train Folder-----
      img----
          0001.png
          0002.png
          .......
      label---
          0001.png
          0002.png
          .......
Validation Folder-----
      img----
          0001.png
          0002.png
          .......
      label---
          0001.png
          0002.png
          .......
Test Folder-----
      img----
          0001.png
          0002.png
          .......
      label---
          0001.png
          0002.png
          .......
  • The ground truth images should have pixels corresponding to the labels. Example: In case of binary segmentation, the pixels in the GT should be 0 or 255.

Training Command:

python train.py --train_dataset "enter train directory" --val_dataset "enter validation directory" --direc 'path for results to be saved' --batch_size 4 --epoch 400 --save_freq 10 --modelname "gatedaxialunet" --learning_rate 0.001 --imgsize 128 --gray "no"
Change modelname to MedT or logo to train them

Testing Command:

python test.py --loaddirec "./saved_model_path/model_name.pth" --val_dataset "test dataset directory" --direc 'path for results to be saved' --batch_size 1 --modelname "gatedaxialunet" --imgsize 128 --gray "no"

The results including predicted segmentations maps will be placed in the results folder along with the model weights. Run the performance metrics code in MATLAB for calculating F1 Score and mIoU.

Notes:

Note that these experiments were conducted in Nvidia Quadro 8000 with 48 GB memory.

Acknowledgement:

The dataloader code is inspired from pytorch-UNet . The axial attention code is developed from axial-deeplab.

Citation:

To add

Open an issue or mail me directly in case of any queries or suggestions.

Owner
Jeya Maria Jose
PhD Student at Johns Hopkins University.
Jeya Maria Jose
TiP-Adapter: Training-free CLIP-Adapter for Better Vision-Language Modeling

TiP-Adapter: Training-free CLIP-Adapter for Better Vision-Language Modeling This is the official code release for the paper 'TiP-Adapter: Training-fre

peng gao 189 Jan 04, 2023
Decentralized Reinforcment Learning: Global Decision-Making via Local Economic Transactions (ICML 2020)

Decentralized Reinforcement Learning This is the code complementing the paper Decentralized Reinforcment Learning: Global Decision-Making via Local Ec

40 Oct 30, 2022
Repository For Programmers Seeking a platform to show their skills

Programming-Nerds Repository For Programmers Seeking Pull Requests In hacktoberfest ❓ What's Hacktoberfest 2021? Hacktoberfest is the easiest way to g

42 Oct 29, 2022
ColossalAI-Benchmark - Performance benchmarking with ColossalAI

Benchmark for Tuning Accuracy and Efficiency Overview The benchmark includes our

HPC-AI Tech 31 Oct 07, 2022
Next-Best-View Estimation based on Deep Reinforcement Learning for Active Object Classification

next_best_view_rl Setup Clone the repository: git clone --recurse-submodules ... In 'third_party/zed-ros-wrapper': git checkout devel Install mujoco `

Christian Korbach 1 Feb 15, 2022
A Pytorch implementation of "Splitter: Learning Node Representations that Capture Multiple Social Contexts" (WWW 2019).

Splitter ⠀⠀ A PyTorch implementation of Splitter: Learning Node Representations that Capture Multiple Social Contexts (WWW 2019). Abstract Recent inte

Benedek Rozemberczki 201 Nov 09, 2022
Weakly Supervised Text-to-SQL Parsing through Question Decomposition

Weakly Supervised Text-to-SQL Parsing through Question Decomposition The official repository for the paper "Weakly Supervised Text-to-SQL Parsing thro

14 Dec 19, 2022
Relative Positional Encoding for Transformers with Linear Complexity

Stochastic Positional Encoding (SPE) This is the source code repository for the ICML 2021 paper Relative Positional Encoding for Transformers with Lin

Antoine Liutkus 48 Nov 16, 2022
Neural Logic Inductive Learning

Neural Logic Inductive Learning This is the implementation of the Neural Logic Inductive Learning model (NLIL) proposed in the ICLR 2020 paper: Learn

36 Nov 28, 2022
Release of the ConditionalQA dataset

ConditionalQA Datasets accompanying the paper ConditionalQA: A Complex Reading Comprehension Dataset with Conditional Answers. Disclaimer This dataset

14 Oct 17, 2022
[CVPR'21] Locally Aware Piecewise Transformation Fields for 3D Human Mesh Registration

Locally Aware Piecewise Transformation Fields for 3D Human Mesh Registration This repository contains the implementation of our paper Locally Aware Pi

sfwang 70 Dec 19, 2022
A flag generation AI created using DeepAIs API

Vex AI or Vexiology AI is an Artifical Intelligence created to generate custom made flag design texts. It uses DeepAIs API. Please be aware that you must include your own DeepAI API key. See instruct

Bernie 10 Apr 06, 2022
Tracking code for the winner of track 1 in the MMP-Tracking Challenge at ICCV 2021 Workshop.

Tracking Code for the winner of track1 in MMP-Trakcing challenge This repository contains our tracking code for the Multi-camera Multiple People Track

DamoCV 29 Nov 13, 2022
Can we learn gradients by Hamiltonian Neural Networks?

Can we learn gradients by Hamiltonian Neural Networks? This project was carried out as part of the Optimization for Machine Learning course (CS-439) a

2 Aug 22, 2022
SuperSDR: multiplatform KiwiSDR + CAT transceiver integrator

SuperSDR SuperSDR integrates a realtime spectrum waterfall and audio receive from any KiwiSDR around the world, together with a local (or remote) cont

Marco Cogoni 30 Nov 29, 2022
MASA-SR: Matching Acceleration and Spatial Adaptation for Reference-Based Image Super-Resolution (CVPR2021)

MASA-SR Official PyTorch implementation of our CVPR2021 paper MASA-SR: Matching Acceleration and Spatial Adaptation for Reference-Based Image Super-Re

DV Lab 126 Dec 20, 2022
Automate issue discovery for your projects against Lightning nightly and releases.

Automated Testing for Lightning EcoSystem Projects Automate issue discovery for your projects against Lightning nightly and releases. You get CPUs, Mu

Pytorch Lightning 41 Dec 24, 2022
Weakly Supervised Learning of Instance Segmentation with Inter-pixel Relations, CVPR 2019 (Oral)

Weakly Supervised Learning of Instance Segmentation with Inter-pixel Relations The code of: Weakly Supervised Learning of Instance Segmentation with I

Jiwoon Ahn 472 Dec 29, 2022
Pytorch implementation of DeepMind's differentiable neural computer paper.

DNC pytorch This is a Pytorch implementation of DeepMind's Differentiable Neural Computer (DNC) architecture introduced in their recent Nature paper:

Yuanpu Xie 91 Nov 21, 2022
Learning Continuous Signed Distance Functions for Shape Representation

DeepSDF This is an implementation of the CVPR '19 paper "DeepSDF: Learning Continuous Signed Distance Functions for Shape Representation" by Park et a

Meta Research 1.1k Jan 01, 2023