Official implementation of "UCTransNet: Rethinking the Skip Connections in U-Net from a Channel-wise Perspective with Transformer"

Overview

[AAAI2022] UCTransNet

This repo is the official implementation of "UCTransNet: Rethinking the Skip Connections in U-Net from a Channel-wise Perspective with Transformer", which has been accepted at AAAI 2022.

(Figure: UCTransNet framework)

We propose a Channel Transformer module (CTrans) and use it to replace the skip connections in the original U-Net; hence the name "U-CTrans-Net".

Requirements

Install the dependencies listed in requirements.txt:

pip install -r requirements.txt

Usage

1. Data Preparation

1.1. GlaS and MoNuSeg Datasets

The original data can be downloaded from the following links:

Then organize the datasets in the following directory structure so that the code can find them:

├── datasets
    ├── GlaS
    │   ├── Test_Folder
    │   │   ├── img
    │   │   └── labelcol
    │   ├── Train_Folder
    │   │   ├── img
    │   │   └── labelcol
    │   └── Val_Folder
    │       ├── img
    │       └── labelcol
    └── MoNuSeg
        ├── Test_Folder
        │   ├── img
        │   └── labelcol
        ├── Train_Folder
        │   ├── img
        │   └── labelcol
        └── Val_Folder
            ├── img
            └── labelcol
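Before training, it can help to confirm that the folders match the layout above. The following is a minimal sketch that only checks the directory tree shown here; the function name `missing_dirs` is hypothetical, and it does not inspect the image or label files inside `img` and `labelcol`.

```python
from pathlib import Path

# Expected layout: <root>/<dataset>/<split>/{img,labelcol}
SPLITS = ("Train_Folder", "Val_Folder", "Test_Folder")
SUBDIRS = ("img", "labelcol")


def missing_dirs(root, datasets=("GlaS", "MoNuSeg")):
    """Return every expected directory that does not exist under `root`."""
    root = Path(root)
    missing = []
    for dataset in datasets:
        for split in SPLITS:
            for sub in SUBDIRS:
                path = root / dataset / split / sub
                if not path.is_dir():
                    missing.append(path)
    return missing


if __name__ == "__main__":
    problems = missing_dirs("datasets")
    if problems:
        print("Missing directories:")
        for path in problems:
            print("  ", path)
    else:
        print("Dataset layout looks complete.")
```

Run it from the directory that contains `datasets`; an empty report only means the folders exist, not that they contain matching image/label pairs.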

1.2. Synapse Dataset

The Synapse dataset we used is provided by TransUNet's authors. Please go to https://github.com/Beckschen/TransUNet/blob/main/datasets/README.md for details.

2. Training

As mentioned in the paper, we introduce two strategies to optimize UCTransNet.

First, adjust the settings in Config.py, which holds all the configurations, including the learning rate, batch size, etc.

2.1. Joint Training

We optimize the convolution parameters in U-Net and the CTrans parameters together with a single loss. Run:

python train_model.py

2.2. Pre-training

Since our method only replaces the skip connections in U-Net, the U-Net parameters can serve as part of the pretrained weights.

By first training a classical U-Net with /nets/UNet.py and then using its weights to initialize UCTransNet, the CTrans module receives better initial features.

This strategy can speed up convergence and, in some cases, improve the final segmentation performance.
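Reusing U-Net weights amounts to copying only those parameters that UCTransNet shares with the plain U-Net. Below is a minimal sketch of that filtering step, under the assumption (not verified against the repo) that shared layers have the same parameter names in /nets/UNet.py and UCTransNet. The helper `select_transferable_weights` is hypothetical; it works on any dict whose values expose a `.shape`, such as a PyTorch `state_dict()`.

```python
def select_transferable_weights(pretrained_state, target_state):
    """Split pretrained entries into those that fit the target model and the rest.

    An entry fits if the target has a parameter with the same name and shape.
    Returns (selected, skipped): a dict of fitting entries and a list of the
    names that were left out.
    """
    selected = {}
    skipped = []
    for name, tensor in pretrained_state.items():
        target = target_state.get(name)
        if target is not None and tuple(target.shape) == tuple(tensor.shape):
            selected[name] = tensor
        else:
            skipped.append(name)
    return selected, skipped
```

With PyTorch, the selected subset could then be applied with `model.load_state_dict(selected, strict=False)`. `strict=False` tolerates the CTrans parameters that the U-Net checkpoint lacks, but PyTorch still raises an error on shape mismatches, which is why the sketch filters by shape first. Inspecting `skipped` is a quick way to confirm that the layers you expected to transfer actually did.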

3. Testing

3.1. Get Pre-trained Models

We provide pre-trained weights for GlaS and MoNuSeg. If you do not want to train the models yourself, you can download them from the following links:

3.2. Test the Model and Visualize the Segmentation Results

First, set the session name in Config.py to the one used during training. Then run:

python test_model.py

This reports the Dice and IoU scores and produces the visualization results.
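For reference, here is a minimal sketch of the standard Dice and IoU definitions for one pair of binary masks P (prediction) and G (ground truth): Dice = 2|P∩G| / (|P| + |G|) and IoU = |P∩G| / |P∪G|. It is not taken from test_model.py, which may threshold probabilities, add a smoothing term, or average over images differently. The function name `dice_and_iou` and the convention of scoring two empty masks as 1.0 are assumptions.

```python
def dice_and_iou(pred, target):
    """Dice and IoU for two binary masks given as flat sequences of 0/1 values.

    Assumed convention: if both masks are empty, both scores are 1.0.
    """
    if len(pred) != len(target):
        raise ValueError("masks must have the same number of pixels")
    inter = sum(1 for p, t in zip(pred, target) if p and t)
    pred_sum = sum(1 for p in pred if p)
    target_sum = sum(1 for t in target if t)
    union = pred_sum + target_sum - inter
    if union == 0:  # both masks empty
        return 1.0, 1.0
    dice = 2 * inter / (pred_sum + target_sum)
    iou = inter / union
    return dice, iou


# One overlapping pixel out of two predicted and two ground-truth pixels.
dice, iou = dice_and_iou([1, 1, 0, 0], [1, 0, 1, 0])
print(dice, iou)  # 0.5 0.3333333333333333
```

For a single mask pair the two scores are tied by IoU = Dice / (2 - Dice), so they rank individual predictions identically; they can diverge once averaged over many images.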

4. Reproducibility

In our code, we carefully set the random seed and put cuDNN in deterministic mode to eliminate randomness. However, some factors may still cause different training results, e.g., the CUDA version, the GPU type, and the number of GPUs. The GPU used in our experiments is an NVIDIA A40 (48 GB), and the CUDA version is 11.2.
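A minimal sketch of this kind of seeding is shown below. The helper name `seed_everything` is hypothetical and the repo's own code may differ; NumPy and PyTorch are treated as optional so the function runs even where they are not installed. The PyTorch calls (`torch.manual_seed`, `torch.cuda.manual_seed_all`, `torch.backends.cudnn.deterministic`, `torch.backends.cudnn.benchmark`) are standard PyTorch settings.

```python
import random


def seed_everything(seed: int) -> None:
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    NumPy and PyTorch are seeded only if they can be imported.
    """
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
    except ImportError:
        return
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # silently ignored when CUDA is unavailable
    torch.backends.cudnn.deterministic = True  # use deterministic cuDNN kernels
    torch.backends.cudnn.benchmark = False  # no autotuning, which can pick different kernels per run
```

These settings do not make every CUDA operation deterministic, as noted for upsampling below. Calling `torch.use_deterministic_algorithms(True)` makes PyTorch raise an error when a known nondeterministic operation is used, which helps locate the remaining sources of variation.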

In multi-GPU settings in particular, the upsampling operation is a major source of nondeterminism. See https://pytorch.org/docs/stable/notes/randomness.html for more details.

We suggest training the model twice to verify whether the randomness has been eliminated. Because we use early stopping, any remaining randomness can change the final performance significantly.
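To see why early stopping amplifies small run-to-run differences, consider a minimal patience-based stopper. This is a generic sketch, not the repo's implementation: the class `EarlyStopping`, the helper `simulate`, the patience value, and the validation curves are all hypothetical.

```python
class EarlyStopping:
    """Stop once the validation score has not improved for `patience` epochs."""

    def __init__(self, patience):
        self.patience = patience
        self.best = float("-inf")
        self.best_epoch = -1
        self.bad_epochs = 0

    def step(self, epoch, score):
        """Record one epoch's score; return True if training should stop."""
        if score > self.best:
            self.best, self.best_epoch, self.bad_epochs = score, epoch, 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


def simulate(val_scores, patience):
    """Return (stop_epoch, best_epoch, best_score) for a sequence of scores."""
    stopper = EarlyStopping(patience)
    for epoch, score in enumerate(val_scores):
        if stopper.step(epoch, score):
            return epoch, stopper.best_epoch, stopper.best
    return len(val_scores) - 1, stopper.best_epoch, stopper.best


# Two hypothetical runs that differ only at epoch 2 (0.74 vs 0.76).
run_a = [0.70, 0.75, 0.74, 0.73, 0.80]
run_b = [0.70, 0.75, 0.76, 0.73, 0.80]
print(simulate(run_a, patience=2))  # (3, 1, 0.75): stops before reaching 0.80
print(simulate(run_b, patience=2))  # (4, 4, 0.8): keeps going and finds 0.80
```

A 0.02 difference at a single epoch decides whether run A stops early, so two nominally identical trainings can end with quite different checkpoints.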

Reference

Citations

If this code is helpful for your study, please cite:

@misc{wang2021uctransnet,
      title={UCTransNet: Rethinking the Skip Connections in U-Net from a Channel-wise Perspective with Transformer}, 
      author={Haonan Wang and Peng Cao and Jiaqi Wang and Osmar R. Zaiane},
      year={2021},
      eprint={2109.04335},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Contact

Haonan Wang ([email protected])
