Official implementation of "UCTransNet: Rethinking the Skip Connections in U-Net from a Channel-wise Perspective with Transformer"

Overview

[AAAI2022] UCTransNet

This repo is the official implementation of "UCTransNet: Rethinking the Skip Connections in U-Net from a Channel-wise Perspective with Transformer" which is accepted at AAAI2022.

framework

We propose a Channel Transformer module (CTrans) and use it to replace the skip connections in original U-Net, thus we name it "U-CTrans-Net".

Requirements

Install from the requirements.txt using:

pip install -r requirements.txt

Usage

1. Data Preparation

1.1. GlaS and MoNuSeg Datasets

The original data can be downloaded in following links:

Then prepare the datasets in the following format for easy use of the code:

├── datasets
    ├── GlaS
    │   ├── Test_Folder
    │   │   ├── img
    │   │   └── labelcol
    │   ├── Train_Folder
    │   │   ├── img
    │   │   └── labelcol
    │   └── Val_Folder
    │       ├── img
    │       └── labelcol
    └── MoNuSeg
        ├── Test_Folder
        │   ├── img
        │   └── labelcol
        ├── Train_Folder
        │   ├── img
        │   └── labelcol
        └── Val_Folder
            ├── img
            └── labelcol

1.2. Synapse Dataset

The Synapse dataset we used is provided by TransUNet's authors. Please go to https://github.com/Beckschen/TransUNet/blob/main/datasets/README.md for details.

2. Training

As mentioned in the paper, we introduce two strategies to optimize UCTransNet.

The first step is to change the settings in Config.py, all the configurations including learning rate, batch size and etc. are in it.

2.1 Jointly Training

We optimize the convolution parameters in U-Net and the CTrans parameters together with a single loss. Run:

python train_model.py

2.2 Pre-training

Our method just replaces the skip connections in U-Net, so the parameters in U-Net can be used as part of pretrained weights.

By first training a classical U-Net using /nets/UNet.py then using the pretrained weights to train the UCTransNet, CTrans module can get better initial features.

This strategy can improve the convergence speed and may improve the final segmentation performance in some cases.

3. Testing

3.1. Get Pre-trained Models

Here, we provide pre-trained weights on GlaS and MoNuSeg, if you do not want to train the models by yourself, you can download them in the following links:

3.2. Test the Model and Visualize the Segmentation Results

First, change the session name in Config.py as the training phase. Then run:

python test_model.py

You can get the Dice and IoU scores and the visualization results.

4. Reproducibility

In our code, we carefully set the random seed and set cudnn as 'deterministic' mode to eliminate the randomness. However, there still exsist some factors which may cause different training results, e.g., the cuda version, GPU types, the number of GPUs and etc. The GPU used in our experiments is NVIDIA A40 (48G) and the cuda version is 11.2.

Especially for multi-GPU cases, the upsampling operation has big problems with randomness. See https://pytorch.org/docs/stable/notes/randomness.html for more details.

When training, we suggest to train the model twice to verify wheather the randomness is eliminated. Because we use the early stopping strategy, the final performance may change significantly due to the randomness.

Reference

Citations

If this code is helpful for your study, please cite:

@misc{wang2021uctransnet,
      title={UCTransNet: Rethinking the Skip Connections in U-Net from a Channel-wise Perspective with Transformer}, 
      author={Haonan Wang and Peng Cao and Jiaqi Wang and Osmar R. Zaiane},
      year={2021},
      eprint={2109.04335},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Contact

Haonan Wang ([email protected])

Owner
Haonan Wang
Haonan Wang
A fast poisson image editing implementation that can utilize multi-core CPU or GPU to handle a high-resolution image input.

Poisson Image Editing - A Parallel Implementation Jiayi Weng (jiayiwen), Zixu Chen (zixuc) Poisson Image Editing is a technique that can fuse two imag

Jiayi Weng 110 Dec 27, 2022
Gym environment for FLIPIT: The Game of "Stealthy Takeover"

gym-flipit Gym environment for FLIPIT: The Game of "Stealthy Takeover" invented by Marten van Dijk, Ari Juels, Alina Oprea, and Ronald L. Rivest. Desi

Lisa Oakley 2 Dec 15, 2021
Keras community contributions

keras-contrib : Keras community contributions Keras-contrib is deprecated. Use TensorFlow Addons. The future of Keras-contrib: We're migrating to tens

Keras 1.6k Dec 21, 2022
Official implementation of the MM'21 paper Constrained Graphic Layout Generation via Latent Optimization

[MM'21] Constrained Graphic Layout Generation via Latent Optimization This repository provides the official code for the paper "Constrained Graphic La

Kotaro Kikuchi 73 Dec 27, 2022
Implementation of Invariant Point Attention, used for coordinate refinement in the structure module of Alphafold2, as a standalone Pytorch module

Invariant Point Attention - Pytorch Implementation of Invariant Point Attention as a standalone module, which was used in the structure module of Alph

Phil Wang 113 Jan 05, 2023
PyTorch implementation of DeepLab v2 on COCO-Stuff / PASCAL VOC

DeepLab with PyTorch This is an unofficial PyTorch implementation of DeepLab v2 [1] with a ResNet-101 backbone. COCO-Stuff dataset [2] and PASCAL VOC

Kazuto Nakashima 995 Jan 08, 2023
Code repository for the work "Multi-Domain Incremental Learning for Semantic Segmentation", accepted at WACV 2022

Multi-Domain Incremental Learning for Semantic Segmentation This is the Pytorch implementation of our work "Multi-Domain Incremental Learning for Sema

Pgxo20 24 Jan 02, 2023
The official repository for "Revealing unforeseen diagnostic image features with deep learning by detecting cardiovascular diseases from apical four-chamber ultrasounds"

Revealing unforeseen diagnostic image features with deep learning by detecting cardiovascular diseases from apical four-chamber ultrasounds The why Im

3 Mar 29, 2022
Python TFLite scripts for detecting objects of any class in an image without knowing their label.

Python TFLite scripts for detecting objects of any class in an image without knowing their label.

Ibai Gorordo 42 Oct 07, 2022
Ansible Automation Example: JSNAPY PRE/POST Upgrade Validation

Ansible Automation Example: JSNAPY PRE/POST Upgrade Validation Overview This example will show how to validate the status of our firewall before and a

Calvin Remsburg 1 Jan 07, 2022
Implementation of Retrieval-Augmented Denoising Diffusion Probabilistic Models in Pytorch

Retrieval-Augmented Denoising Diffusion Probabilistic Models (wip) Implementation of Retrieval-Augmented Denoising Diffusion Probabilistic Models in P

Phil Wang 55 Jan 01, 2023
This repo contains the code required to train the multivariate time-series Transformer.

Multi-Variate Time-Series Transformer This repo contains the code required to train the multivariate time-series Transformer. Download the data The No

Gregory Duthé 4 Nov 24, 2022
An official PyTorch implementation of the TKDE paper "Self-Supervised Graph Representation Learning via Topology Transformations".

Self-Supervised Graph Representation Learning via Topology Transformations This repository is the official PyTorch implementation of the following pap

Hsiang Gao 2 Oct 31, 2022
A deep learning object detector framework written in Python for supporting Land Search and Rescue Missions.

AIR: Aerial Inspection RetinaNet for supporting Land Search and Rescue Missions AIR is a deep learning based object detection solution to automate the

Accenture 13 Dec 22, 2022
Weighing Counts: Sequential Crowd Counting by Reinforcement Learning

LibraNet This repository includes the official implementation of LibraNet for crowd counting, presented in our paper: Weighing Counts: Sequential Crow

Hao Lu 18 Nov 05, 2022
PolyGlot, a fuzzing framework for language processors

PolyGlot, a fuzzing framework for language processors Build We tested PolyGlot on Ubuntu 18.04. Get the source code: git clone https://github.com/s3te

Software Systems Security Team at Penn State University 79 Dec 27, 2022
This game was designed to encourage young people not to gamble on lotteries, as the probablity of correctly guessing the number is infinitesimal!

Lottery Simulator 2022 for Web Launch Application Developed by John Seong in Ontario. This game was designed to encourage young people not to gamble o

John Seong 2 Sep 02, 2022
[NeurIPS 2021] “Improving Contrastive Learning on Imbalanced Data via Open-World Sampling”,

Improving Contrastive Learning on Imbalanced Data via Open-World Sampling Introduction Contrastive learning approaches have achieved great success in

VITA 24 Dec 17, 2022
[Link]deep_portfolo - Use Reforcemet earg ad Supervsed learg to Optmze portfolo allocato []

rl_portfolio This Repository uses Reinforcement Learning and Supervised learning to Optimize portfolio allocation. The goal is to make profitable agen

Deepender Singla 165 Dec 02, 2022
Fuzzy Overclustering (FOC)

Fuzzy Overclustering (FOC) In real-world datasets, we need consistent annotations between annotators to give a certain ground-truth label. However, in

2 Nov 08, 2022