Contextual Attention Network: Transformer Meets U-Net

Overview


Contextual attention network for medical image segmentation with state-of-the-art results on skin lesion segmentation and multiple myeloma cell segmentation. This method incorporates a transformer module into a U-Net structure to capture long-range dependencies together with rich local information. If this code helps with your research, please consider citing the following paper:

Reza Azad, Moein Heidari, Yuli Wu, and Dorit Merhof, "Contextual Attention Network: Transformer Meets U-Net", arXiv preprint arXiv:2203.01932, 2022.

@article{reza2022contextual,
  title={Contextual Attention Network: Transformer Meets U-Net},
  author={Azad, Reza and Heidari, Moein and Wu, Yuli and Merhof, Dorit},
  journal={arXiv preprint arXiv:2203.01932},
  year={2022}
}

If you find this repository useful, please consider starring it. Thanks!

Requirements

This code is implemented in Python using the PyTorch library and has been tested on Ubuntu, though it should be compatible with similar environments. The following environment and libraries are needed to run the code:

  • Python 3
  • PyTorch

Run Demo

To train the model and evaluate it on each dataset, follow the steps below:
1- Download the ISIC 2018 training dataset from this link and extract both the training images and the ground truth folders into the dataset_isic18 folder.
2- Run Prepare_ISIC2018.py to prepare the data and divide it into training, validation, and test sets.
3- Run train_skin.py to train the model using the training and validation sets. The model is trained for 100 epochs, and the weights that perform best on the validation set are saved.
4- To compute the performance metrics and produce segmentation results, run evaluate_skin.py. It reports the performance measures and saves the related results in the results folder.
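Steps 2 and 3 above boil down to two ideas: a reproducible train/validation/test split, and keeping the weights from the epoch with the best validation score. The sketch below illustrates both under stated assumptions: `split_dataset` and `best_epoch` are hypothetical helpers (not functions from this repository), the 70/10/20 split fractions are illustrative rather than the ratios used by Prepare_ISIC2018.py, and the validation metric is assumed to be higher-is-better (for example Dice).

```python
import random
from typing import List, Sequence, Tuple


def split_dataset(items: Sequence[str], train_frac: float = 0.7,
                  val_frac: float = 0.1, seed: int = 0
                  ) -> Tuple[List[str], List[str], List[str]]:
    """Reproducibly shuffle items and split them into train/val/test lists.

    Hypothetical helper: the 70/10/20 default fractions are an illustrative
    assumption, not the ratios used by Prepare_ISIC2018.py.
    """
    if not 0 < train_frac + val_frac < 1:
        raise ValueError("train_frac + val_frac must lie in (0, 1)")
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)  # fixed seed -> same split every run
    n_train = round(len(shuffled) * train_frac)
    n_val = round(len(shuffled) * val_frac)
    train = shuffled[:n_train]
    val = shuffled[n_train:n_train + n_val]
    test = shuffled[n_train + n_val:]  # everything left over is the test set
    return train, val, test


def best_epoch(val_scores: Sequence[float]) -> Tuple[int, float]:
    """Return (epoch, score) for the highest validation score.

    This is the epoch whose weights a 'save the best checkpoint' loop would
    keep, assuming a higher-is-better metric such as Dice. Epochs are
    numbered from 1; on ties the earliest epoch wins.
    """
    if not val_scores:
        raise ValueError("val_scores must not be empty")
    best_ep, best_score = 0, float("-inf")
    for epoch, score in enumerate(val_scores, start=1):
        if score > best_score:  # strict '>' keeps the earliest best epoch
            best_ep, best_score = epoch, score
    return best_ep, best_score


if __name__ == "__main__":
    names = [f"img_{i:04d}.jpg" for i in range(100)]
    train, val, test = split_dataset(names)
    print(len(train), len(val), len(test))  # 70 10 20
    print(best_epoch([0.71, 0.80, 0.78, 0.83, 0.82]))  # (4, 0.83)
```

In a real training loop the same comparison would decide when to overwrite the saved weights file, so only the best-scoring epoch survives after all 100 epochs.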

Notice: To train and evaluate on ISIC 2017 and PH2, follow the steps below:

ISIC 2017- Download the ISIC 2017 training dataset from this link and extract both the training images and the ground truth folders into the dataset_isic18\7 folder.
Then run Prepare_ISIC2017.py to prepare the data and divide it into training, validation, and test sets.
PH2- Download the PH2 dataset from this link, extract it, and then run Prepare_ph2.py to prepare the data and divide it into training, validation, and test sets.
Follow steps 3 and 4 for model training and performance evaluation. For the PH2 dataset, first train the model on ISIC 2017 and then fine-tune the trained model on PH2.

Quick Overview

Diagram of the proposed method

Perceptual visualization of the proposed Contextual Attention module.

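The method pairs a U-Net with a transformer module so that distant image regions can exchange information. The core operation of any transformer layer is self-attention, sketched below for a feature map flattened into one token per spatial position. This is a generic single-head scaled dot-product attention in plain Python, written for readability; it is not the paper's Contextual Attention module, the helper names (`flatten_feature_map`, `self_attention`) are hypothetical, and the learned query/key/value projections of a real transformer are replaced by identity maps, an assumption made to keep the sketch short.

```python
import math
from typing import List

Matrix = List[List[float]]


def softmax(row: List[float]) -> List[float]:
    # Subtract the max before exponentiating for numerical stability.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def flatten_feature_map(fmap: List[Matrix]) -> Matrix:
    """Turn a C x H x W feature map into an (H*W) x C token matrix."""
    c, h, w = len(fmap), len(fmap[0]), len(fmap[0][0])
    return [[fmap[ch][i][j] for ch in range(c)]
            for i in range(h) for j in range(w)]


def self_attention(tokens: Matrix) -> Matrix:
    """Single-head scaled dot-product self-attention.

    ASSUMPTION: identity query/key/value projections (a real transformer
    learns these). Each output row is a softmax-weighted mix of ALL input
    rows, so every spatial position can draw on every other position.
    """
    d = len(tokens[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for q in tokens:
        # Similarity of this query position to every key position.
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in tokens]
        weights = softmax(scores)
        # Weighted sum of the value vectors (here, the tokens themselves).
        out.append([sum(w * v[j] for w, v in zip(weights, tokens))
                    for j in range(d)])
    return out


if __name__ == "__main__":
    # C=2 channels on a 2x2 grid: top-left and bottom-right are "active".
    fmap = [[[1.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 1.0]]]
    for row in self_attention(flatten_feature_map(fmap)):
        print([round(x, 3) for x in row])
```

Because every output token mixes all input tokens, the top-left position picks up a nonzero contribution from the bottom-right one after a single layer, whereas small convolution kernels need many stacked layers to connect distant pixels.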

Results

To evaluate the performance of the proposed method, two challenging tasks in medical image segmentation are considered. The results of the proposed approach are presented below.

Task 1: Skin Lesion Segmentation

Performance Comparison on Skin Lesion Segmentation

To compare the proposed method with state-of-the-art approaches on skin lesion segmentation, we considered the ISIC 2017 dataset.

Methods (on ISIC 2017)        Dice Score  Sensitivity  Specificity  Accuracy
Ronneberger et al. U-Net      0.8159      0.8172       0.9680       0.9164
Oktay et al. Attention U-Net  0.8082      0.7998       0.9776       0.9145
Lei et al. DAGAN              0.8425      0.8363       0.9716       0.9304
Chen et al. TransUNet         0.8123      0.8263       0.9577       0.9207
Asadi et al. MCGU-Net         0.8927      0.8502       0.9855       0.9570
Valanarasu et al. MedT        0.8037      0.8064       0.9546       0.9090
Wu et al. FAT-Net             0.8500      0.8392       0.9725       0.9326
Azad et al. Proposed TMUnet   0.9164      0.9128       0.9789       0.9660
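The four columns above are standard pixel-wise measures derived from true/false positives and negatives of a binary mask. The sketch below shows one way to compute them from flattened 0/1 masks. `binary_metrics` is a hypothetical helper, not necessarily the code in evaluate_skin.py; in particular, pooling pixels (rather than averaging per image) and reporting undefined ratios as 0.0 are assumptions.

```python
from typing import Dict, Sequence


def binary_metrics(pred: Sequence[int], gt: Sequence[int]) -> Dict[str, float]:
    """Pixel-wise Dice, sensitivity, specificity and accuracy for 0/1 masks.

    `pred` and `gt` are flattened binary masks of equal length (1 = lesion).
    Pixels are pooled, so calling this on a concatenation of all test masks
    gives dataset-level scores; per-image averaging would differ.
    """
    if len(pred) != len(gt):
        raise ValueError("pred and gt must have the same number of pixels")
    tp = sum(1 for p, g in zip(pred, gt) if p and g)          # lesion found
    tn = sum(1 for p, g in zip(pred, gt) if not p and not g)  # background kept
    fp = sum(1 for p, g in zip(pred, gt) if p and not g)      # false alarm
    fn = sum(1 for p, g in zip(pred, gt) if not p and g)      # missed lesion

    def ratio(num: int, den: int) -> float:
        # Undefined ratios (e.g. Dice when both masks are empty) are reported
        # as 0.0 here; other code bases may use 1.0 or skip such cases.
        return num / den if den else 0.0

    return {
        "dice": ratio(2 * tp, 2 * tp + fp + fn),
        "sensitivity": ratio(tp, tp + fn),  # true positive rate
        "specificity": ratio(tn, tn + fp),  # true negative rate
        "accuracy": ratio(tp + tn, len(gt)),
    }


if __name__ == "__main__":
    print(binary_metrics([1, 1, 0, 0, 1, 0], [1, 0, 0, 1, 1, 0]))
```

For the six-pixel example there are 2 true positives, 2 true negatives, 1 false positive and 1 false negative, so all four metrics come out to 2/3.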

For more results on the ISIC 2018 and PH2 datasets, please refer to the paper.

Skin lesion segmentation results on test data

Skin lesion segmentation results: (a) Input images. (b) Ground truth. (c) U-Net. (d) Gated Axial-Attention. (e) Proposed method without the contextual attention module. (f) Proposed method.

Task 2: Multiple Myeloma Cell Segmentation

Performance Evaluation on the Multiple Myeloma Cell Segmentation Task

Methods                           mIoU
Frequency recalibration U-Net     0.9392
XLAB Insights                     0.9360
DSC-IITISM                        0.9356
Multi-scale attention DeepLabv3+  0.9065
U-Net                             0.7665
Baseline                          0.9172
Proposed                          0.9395
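mIoU is the intersection over union averaged over classes. The sketch below computes it for flattened label maps. `mean_iou` is a hypothetical helper; skipping classes that appear in neither map is an assumption, and the official multiple myeloma challenge evaluation may compute IoU differently (for example, per cell instance), so the numbers above are not guaranteed to follow this exact definition.

```python
from typing import Sequence


def mean_iou(pred: Sequence[int], gt: Sequence[int], num_classes: int) -> float:
    """Mean intersection-over-union over classes for flattened label maps.

    ASSUMPTION: classes absent from both `pred` and `gt` are skipped rather
    than counted as IoU 1.0 or 0.0.
    """
    if len(pred) != len(gt):
        raise ValueError("pred and gt must have the same number of pixels")
    ious = []
    for c in range(num_classes):
        inter = sum(1 for p, g in zip(pred, gt) if p == c and g == c)
        union = sum(1 for p, g in zip(pred, gt) if p == c or g == c)
        if union:  # class present in at least one of the two maps
            ious.append(inter / union)
    return sum(ious) / len(ious) if ious else 0.0


if __name__ == "__main__":
    # Per-class IoU: class 0 -> 1/3, class 1 -> 2/3, class 2 -> 1/2.
    print(mean_iou([0, 0, 1, 1, 2, 2], [0, 1, 1, 1, 2, 0], num_classes=3))
```

Averaging per class rather than pooling all pixels keeps small classes, such as thin cell boundaries or nuclei, from being swamped by the background.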

Multiple Myeloma Cell Segmentation results


Model weights

You can download the learned weights for each dataset in the following table.

Dataset     Learned weights
ISIC 2018   TMUnet
ISIC 2017   TMUnet
PH2         TMUnet

Query

All implementations were done by Reza Azad and Moein Heidari. For any queries, please contact us:

rezazad68@gmail.com
moeinheidari7829@gmail.com