Official and maintained implementation of the paper "OSS-Net: Memory Efficient High Resolution Semantic Segmentation of 3D Medical Data" [BMVC 2021].

Overview

OSS-Net: Memory Efficient High Resolution Semantic Segmentation of 3D Medical Data

License: MIT

Christoph Reich, Tim Prangemeier, Özdemir Cetin & Heinz Koeppl



This repository includes the official and maintained PyTorch implementation of the paper OSS-Net: Memory Efficient High Resolution Semantic Segmentation of 3D Medical Data.

Abstract

Convolutional neural networks (CNNs) are the current state-of-the-art meta-algorithm for volumetric segmentation of medical data, for example, to localize COVID-19 infected tissue on computer tomography scans or the detection of tumour volumes in magnetic resonance imaging. A key limitation of 3D CNNs on voxelised data is that the memory consumption grows cubically with the training data resolution. Occupancy networks (O-Nets) are an alternative for which the data is represented continuously in a function space and 3D shapes are learned as a continuous decision boundary. While O-Nets are significantly more memory efficient than 3D CNNs, they are limited to simple shapes, are relatively slow at inference, and have not yet been adapted for 3D semantic segmentation of medical data. Here, we propose Occupancy Networks for Semantic Segmentation (OSS-Nets) to accurately and memory-efficiently segment 3D medical data. We build upon the original O-Net with modifications for increased expressiveness leading to improved segmentation performance comparable to 3D CNNs, as well as modifications for faster inference. We leverage local observations to represent complex shapes and prior encoder predictions to expedite inference. We showcase OSS-Net's performance on 3D brain tumour and liver segmentation against a function space baseline (O-Net), a performance baseline (3D residual U-Net), and an efficiency baseline (2D residual U-Net). OSS-Net yields segmentation results similar to the performance baseline and superior to the function space and efficiency baselines. In terms of memory efficiency, OSS-Net consumes comparable amounts of memory as the function space baseline, somewhat more memory than the efficiency baseline and significantly less than the performance baseline. As such, OSS-Net enables memory-efficient and accurate 3D semantic segmentation that can scale to high resolutions.

If you find this research useful in your work, please cite our paper:

@inproceedings{Reich2021,
        title={{OSS-Net: Memory Efficient High Resolution Semantic Segmentation of 3D Medical Data}},
        author={Reich, Christoph and Prangemeier, Tim and Cetin, {\"O}zdemir and Koeppl, Heinz},
        booktitle={British Machine Vision Conference},
        year={2021},
        organization={British Machine Vision Association},
}

Dependencies

All required Python packages can be installed by:

pip install -r requirements.txt

To install the official implementation of the Padé Activation Unit [1] (taken from the official repository) run:

cd pade_activation_unit/cuda
python setup.py build install

The code is tested with PyTorch 1.8.1 and CUDA 11.1 on Linux with Python 3.8.5. Other versions newer than PyTorch 1.7.0 and CUDA 10.1 should also work.
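As a quick sanity check before installing the CUDA extension, the version requirement can be sketched as a small helper. This is an illustrative snippet, not part of the repository: the names `parse_version` and `is_supported` are hypothetical, and treating the stated minimum versions as inclusive is an assumption.

```python
import re

# Oldest versions the README names as a lower bound (treated as inclusive here, an assumption).
MIN_TORCH = (1, 7, 0)
MIN_CUDA = (10, 1)


def parse_version(text):
    """Turn a version string such as '1.8.1+cu111' into a tuple of ints, e.g. (1, 8, 1)."""
    numeric_part = re.match(r"\d+(?:\.\d+)*", text)
    if numeric_part is None:
        raise ValueError(f"no version number found in {text!r}")
    return tuple(int(piece) for piece in numeric_part.group(0).split("."))


def is_supported(torch_version, cuda_version):
    """Return True if both versions reach the minimum named in the README."""
    torch_ok = parse_version(torch_version) >= MIN_TORCH
    cuda_ok = parse_version(cuda_version) >= MIN_CUDA
    return torch_ok and cuda_ok


# With PyTorch installed, the live values could be passed in like this:
#   import torch
#   is_supported(torch.__version__, torch.version.cuda or "0")
tested_setup_ok = is_supported("1.8.1", "11.1")
```

`torch.version.cuda` is `None` on CPU-only builds, which is why the commented usage falls back to `"0"`.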

Data

The BraTS 2020 dataset can be downloaded here and the LiTS dataset can be downloaded here. Please note that an account is required on both websites to log in and download the data.

The used training and validation split of the BraTS 2020 dataset is available here.

To generate the border maps, which are required if border-based sampling is used, please run the generate_borders_bra_ts_2020.py and generate_borders_lits.py scripts.
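To give an intuition for what a border map is, here is a minimal pure-Python sketch that marks every foreground voxel with at least one background neighbour. It is an assumption-laden illustration: the function name `border_map`, the 6-neighbourhood, and the binary label format are all hypothetical, and the repository's scripts may define and store borders differently.

```python
def border_map(label):
    """Return a grid of the same shape with 1 where a foreground voxel touches background.

    `label` is a nested list indexed as label[z][y][x] holding 0 (background) or 1 (foreground).
    Voxels outside the volume are treated as background.
    """
    depth, height, width = len(label), len(label[0]), len(label[0][0])
    # 6-neighbourhood: one step along each axis in both directions.
    neighbours = [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1)]

    def is_foreground(z, y, x):
        inside = 0 <= z < depth and 0 <= y < height and 0 <= x < width
        return inside and label[z][y][x] == 1

    border = [[[0] * width for _ in range(height)] for _ in range(depth)]
    for z in range(depth):
        for y in range(height):
            for x in range(width):
                if label[z][y][x] != 1:
                    continue
                if any(not is_foreground(z + dz, y + dy, x + dx) for dz, dy, dx in neighbours):
                    border[z][y][x] = 1
    return border


# A solid 3x3x3 cube inside a 5x5x5 volume: only the centre voxel is interior.
volume = [[[1 if 1 <= z <= 3 and 1 <= y <= 3 and 1 <= x <= 3 else 0
            for x in range(5)] for y in range(5)] for z in range(5)]
borders = border_map(volume)
border_count = sum(value for plane in borders for row in plane for value in row)
```

Sampling training coordinates preferentially near such voxels concentrates supervision on the decision boundary, which is the general idea behind border-based sampling.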

Trained Models

Table 1. Segmentation results of trained networks. Weights are generally available here and specific models are linked below.

| Model | Dice (↑) BraTS 2020 | IoU (↑) BraTS 2020 | Dice (↑) LiTS | IoU (↑) LiTS | | |
|---|---|---|---|---|---|---|
| O-Net [2] | 0.7016 | 0.5615 | 0.6506 | 0.4842 | - | - |
| OSS-Net A | 0.8592 | 0.7644 | 0.7127 | 0.5579 | weights BraTS | weights LiTS |
| OSS-Net B | 0.8541 | 0.7572 | 0.7585 | 0.6154 | weights BraTS | weights LiTS |
| OSS-Net C | 0.8842 | 0.7991 | 0.7616 | 0.6201 | weights BraTS | weights LiTS |
| OSS-Net D | 0.8774 | 0.7876 | 0.7566 | 0.6150 | weights BraTS | weights LiTS |
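For readers unfamiliar with the two metrics in Table 1, the following sketch computes the standard Dice coefficient and intersection over union (IoU) for a pair of binary masks. It is not the repository's evaluation code: the function name `dice_and_iou` is hypothetical, the masks are flattened lists of 0/1 values, and how the reported numbers are averaged over volumes is not stated here.

```python
def dice_and_iou(prediction, label):
    """Compute Dice and IoU for two equally long sequences of 0/1 values."""
    if len(prediction) != len(label):
        raise ValueError("prediction and label must have the same length")
    intersection = sum(p * l for p, l in zip(prediction, label))
    prediction_sum = sum(prediction)
    label_sum = sum(label)
    union = prediction_sum + label_sum - intersection
    if union == 0:
        # Both masks are empty; scoring this as a perfect match is a convention, not a rule.
        return 1.0, 1.0
    dice = 2.0 * intersection / (prediction_sum + label_sum)
    iou = intersection / union
    return dice, iou


# Two voxels overlap, each mask has three foreground voxels.
prediction = [1, 1, 1, 0, 0, 0]
label = [0, 1, 1, 1, 0, 0]
dice, iou = dice_and_iou(prediction, label)
```

Both metrics range from 0 to 1 and higher is better; for a single pair of masks Dice is never smaller than IoU.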

Usage

Training

To reproduce the results presented in Table 1, we provide multiple sh scripts, which can be found in the scripts folder. Please change the dataset path and CUDA devices according to your system.

To perform training runs with different settings, use the command line arguments of the train_oss_net.py script. It takes the following command line arguments:

| Argument | Default value | Info |
|---|---|---|
| --train | False | Binary flag. If set, training will be performed. |
| --test | False | Binary flag. If set, testing will be performed. |
| --cuda_devices | "0, 1" | String of CUDA device indexes to be used. Indexes must be separated by a comma. |
| --cpu | False | Binary flag. If set, all operations are performed on the CPU (not recommended). |
| --epochs | 50 | Number of epochs to perform while training. |
| --batch_size | 8 | Batch size to be used while training. |
| --training_samples | 2 ** 14 | Number of coordinates to be sampled during training. |
| --load_model | "" | Path to model to be loaded. |
| --segmentation_loss_factor | 0.1 | Auxiliary segmentation loss factor to be utilized. |
| --network_config | "" | Type of network configuration to be utilized (see). |
| --dataset | "BraTS" | Dataset to be utilized ("BraTS" or "LITS"). |
| --dataset_path | "BraTS2020" | Path to dataset. |
| --uniform_sampling | False | Binary flag. If set, locations are sampled uniformly during training. |
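To make the flag semantics in the table above concrete, here is a hypothetical `argparse` parser that mirrors the documented names and defaults. It is a sketch, not the actual contents of train_oss_net.py: the helper names `build_parser` and `parse_device_indexes` are invented for illustration, and the argument types are inferred from the default values.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the documented flags and defaults."""
    parser = argparse.ArgumentParser(description="Sketch of the train_oss_net.py arguments")
    # Binary flags are False unless given on the command line.
    parser.add_argument("--train", action="store_true", default=False)
    parser.add_argument("--test", action="store_true", default=False)
    parser.add_argument("--cpu", action="store_true", default=False)
    parser.add_argument("--uniform_sampling", action="store_true", default=False)
    # Options that take a value.
    parser.add_argument("--cuda_devices", type=str, default="0, 1")
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--training_samples", type=int, default=2 ** 14)
    parser.add_argument("--load_model", type=str, default="")
    parser.add_argument("--segmentation_loss_factor", type=float, default=0.1)
    parser.add_argument("--network_config", type=str, default="")
    parser.add_argument("--dataset", type=str, default="BraTS", choices=["BraTS", "LITS"])
    parser.add_argument("--dataset_path", type=str, default="BraTS2020")
    return parser


def parse_device_indexes(cuda_devices):
    """Split a comma-separated device string such as '0, 1' into a list of ints."""
    return [int(index) for index in cuda_devices.split(",") if index.strip()]


# Equivalent to: python train_oss_net.py --train --cuda_devices "0" --batch_size 4
args = build_parser().parse_args(["--train", "--cuda_devices", "0", "--batch_size", "4"])
devices = parse_device_indexes(args.cuda_devices)
```

Flags that are not passed keep their defaults, so in this example `args.training_samples` is 2 ** 14 = 16384 and `args.test` stays `False`.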

Please note that the names of the OSS-Net variants used in the code differ from those used in the paper and in Table 1.

Inference

To perform inference, use the inference_oss_net.py script. The script takes the following command line arguments:

| Argument | Default value | Info |
|---|---|---|
| --cuda_devices | "0, 1" | String of CUDA device indexes to be used. Indexes must be separated by a comma. |
| --cpu | False | Binary flag. If set, all operations are performed on the CPU (not recommended). |
| --load_model | "" | Path to model to be loaded. |
| --network_config | "" | Type of network configuration to be utilized (see). |
| --dataset | "BraTS" | Dataset to be utilized ("BraTS" or "LITS"). |
| --dataset_path | "BraTS2020" | Path to dataset. |

During inference the predicted occupancy voxel grid, the mesh prediction, and the label as a mesh are saved. The meshes are saved as PyTorch (.pt) files and also as .obj files. The occupancy grid is only saved as a PyTorch file.
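The saved .obj meshes can be inspected without any extra dependencies. The sketch below is a minimal Wavefront .obj reader, assuming the files contain plain `v` (vertex) and `f` (face) records with positive indexes; the function name `read_obj` and the file name `example.obj` are hypothetical, and the actual output file names of inference_oss_net.py are not assumed here. The .pt files can presumably be opened with `torch.load`.

```python
import os
import tempfile


def read_obj(path):
    """Read vertex positions and faces from a Wavefront .obj file.

    Only `v` and `f` records are handled; face indexes are converted to 0-based.
    """
    vertices, faces = [], []
    with open(path, "r") as handle:
        for line in handle:
            parts = line.split()
            if not parts:
                continue
            if parts[0] == "v":
                vertices.append(tuple(float(value) for value in parts[1:4]))
            elif parts[0] == "f":
                # A face entry may look like "3", "3/1" or "3//2"; keep the vertex index only.
                faces.append(tuple(int(entry.split("/")[0]) - 1 for entry in parts[1:]))
    return vertices, faces


# Round trip on a single triangle written to a temporary file.
with tempfile.TemporaryDirectory() as folder:
    example_path = os.path.join(folder, "example.obj")
    with open(example_path, "w") as handle:
        handle.write("v 0 0 0\nv 1 0 0\nv 0 1 0\nf 1 2 3\n")
    vertices, faces = read_obj(example_path)
```

The returned lists can be handed to any mesh viewer or used to compare the predicted mesh against the label mesh.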

Acknowledgements

We thank Marius Memmel and Nicolas Wagner for the insightful discussions, Alexander Christ and Tim Kircher for giving feedback on the first draft, and Markus Baier as well as Bastian Alt for aid with the computational setup.

This work was supported by the Landesoffensive für wissenschaftliche Exzellenz as part of the LOEWE Schwerpunkt CompuGene. H.K. acknowledges support from the European Research Council (ERC) with the consolidator grant CONSYN (nr. 773196). O.C. is supported by the Alexander von Humboldt Foundation Philipp Schwartz Initiative.

References

[1] @inproceedings{Molina2020Padé,
        title={{Pad\'{e} Activation Units: End-to-end Learning of Flexible Activation Functions in Deep Networks}},
        author={Alejandro Molina and Patrick Schramowski and Kristian Kersting},
        booktitle={International Conference on Learning Representations},
        year={2020}
}
[2] @inproceedings{Mescheder2019,
        title={{Occupancy Networks: Learning 3D Reconstruction in Function Space}},
        author={Mescheder, Lars and Oechsle, Michael and Niemeyer, Michael and Nowozin, Sebastian and Geiger, Andreas},
        booktitle={CVPR},
        pages={4460--4470},
        year={2019}
}