Accepted at ICCV-2021: Workshop on Computer Vision for Automated Medical Diagnosis (CVAMD)

Overview

Is it Time to Replace CNNs with Transformers for Medical Images?


Convolutional Neural Networks (CNNs) have reigned for a decade as the de facto approach to automated medical image diagnosis. Recently, vision transformers (ViTs) have appeared as a competitive alternative to CNNs, yielding similar levels of performance while possessing several interesting properties that could prove beneficial for medical imaging tasks. In this work, we explore whether it is time to move to transformer-based models or whether we should keep working with CNNs. Can we trivially switch to transformers? If so, what are the advantages and drawbacks of switching to ViTs for medical image diagnosis? We consider these questions in a series of experiments on three mainstream medical image datasets. Our findings show that, while CNNs perform better when trained from scratch, off-the-shelf vision transformers using default hyperparameters are on par with CNNs when pretrained on ImageNet, and outperform their CNN counterparts when pretrained using self-supervision.

Environment setup

To build the Docker image from the Dockerfile, run the following command:
docker build -f Dockerfile -t med_trans \
--build-arg UID=$(id -u) \
--build-arg GID=$(id -g) \
--build-arg USER=$(whoami) \
--build-arg GROUP=$(id -g -n) .

Usage:

  • Training: python classification.py
  • Training with DINO: python classification.py --dino
  • Testing (using json file): python classification.py --test
  • Testing (using saved checkpoint): python classification.py --checkpoint CheckpointName --test
  • Find a suitable learning rate (learning rate finder): python classification.py --lr_finder
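The run modes above map onto a handful of command-line flags. The sketch below shows one hypothetical way such flags could be parsed with Python's standard argparse module. The flag names come from the list above; build_parser, select_mode and the mode strings are illustrative assumptions, not the actual internals of classification.py.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser mirroring the flags listed under "Usage";
    # the real classification.py may define them differently.
    parser = argparse.ArgumentParser(description="Train or test a classifier.")
    parser.add_argument("--dino", action="store_true",
                        help="train with DINO self-supervision")
    parser.add_argument("--test", action="store_true",
                        help="evaluate on the test set instead of training")
    parser.add_argument("--checkpoint", type=str, default=None,
                        help="name of a saved checkpoint to load")
    parser.add_argument("--lr_finder", action="store_true",
                        help="run the learning rate finder")
    return parser


def select_mode(args: argparse.Namespace) -> str:
    # Map the parsed flags to the run modes described in the README.
    if args.lr_finder:
        return "lr_finder"
    if args.test:
        # Testing either from a named checkpoint or from the json config.
        return "test_checkpoint" if args.checkpoint else "test_from_config"
    return "train_dino" if args.dino else "train"
```

For example, select_mode(build_parser().parse_args(["--checkpoint", "CheckpointName", "--test"])) returns "test_checkpoint", while an empty argument list selects plain training.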

Configuration (json file)

  • dataset_params
    • dataset: Name of the dataset (ISIC2019, APTOS2019, DDSM)
    • data_location: Directory where the datasets are stored
    • train_transforms: Defines the augmentations for the training set
    • val_transforms: Defines the augmentations for the validation set
    • test_transforms: Defines the augmentations for the test set
  • dataloader_params: Defines the dataloader parameters (batch size, num_workers, etc.)
  • model_params
    • backbone_type: Type of the backbone model (e.g., resnet50, deit_small)
    • transformers_params: Additional hyperparameters for the transformers
      • img_size: The size of the input images
      • patch_size: The patch size to use for patching the input
      • pretrained_type: If supervised, it loads ImageNet weights obtained with supervised learning. If dino, it loads ImageNet weights obtained with self-supervised learning using DINO.
    • pretrained: If True, it uses ImageNet pretrained weights
    • freeze_backbone: If True, it freezes the backbone network
    • DINO: Controls the hyperparameters used when training with DINO
  • optimization_params: Defines learning rate, weight decay, learning rate schedule etc.
    • optimizer: The default optimizer's parameters
      • type: The optimizer's type
      • autoscale_rl: If True, it scales the learning rate according to the batch size
      • params: Defines the learning rate and the weight decay value
    • LARS_params: If use=True and the batch size is >= batch_act_thresh, LARS is used as the optimizer
    • scheduler: Defines the learning rate schedule
      • type: A list of schedulers to use
      • params: Sets the hyperparameters of the schedulers
  • training_params: Defines the training parameters
    • model_name: The model's name
    • val_every: Sets the frequency of the validation step (in epochs; float)
    • log_every: Sets the frequency of logging (in iterations; int)
    • save_best_model: If True, it saves the best model according to the validation metrics
    • log_embeddings: If True, it creates UMAP plots of the embeddings at each validation step
    • knn_eval: If True, validation also computes scores using k-NN evaluation
    • grad_clipping: If > 0, it clips the gradients
    • use_tensorboard: If True, it will use tensorboard for logging instead of wandb
    • use_mixed_precision: If True, it will use mixed precision
    • save_dir: Directory where the model's checkpoints and other outputs are saved
  • system_params: Defines if GPUs are used, which GPUs etc.
  • log_params: Project and run name for the logger (we are using Weights & Biases by default)
  • lr_finder: Defines the parameters of the learning rate finder
    • grid_search_params
      • min_pow, max_pow: The minimum and maximum powers of 10 for the search
      • resolution: How many different learning rates to try
      • n_epochs: Maximum number of epochs for each training session
      • random_lr: If True, it uses random learning rates within the accepted range
      • keep_schedule: If True, it keeps the learning rate schedule
      • report_intermediate_steps: If True, it validates and logs intermediate results throughout each training session
  • transfer_learning_params: Turns on or off transfer learning from pretrained models
    • use_pretrained: If True, it will use a pretrained model as a backbone
    • pretrained_model_name: The pretrained model's name
    • pretrained_path: Path to the pretrained model's directory
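A configuration file combining the sections above might look like the following sketch, which builds the structure in Python and serializes it with the standard json module. Only the key names are taken from this README; all values, the optimizer and scheduler names, the dataloader keys, and the max_pow spelling are hypothetical placeholders, and the DINO sub-section is omitted because its keys are not documented here. The num_patches helper shows how img_size and patch_size interact: a ViT cuts the input into non-overlapping patches, so a 224-pixel image with 16-pixel patches yields (224 // 16) ** 2 = 196 patches.

```python
import json

# Top-level sections listed in the README.
REQUIRED_SECTIONS = [
    "dataset_params", "dataloader_params", "model_params",
    "optimization_params", "training_params", "system_params",
    "log_params", "lr_finder", "transfer_learning_params",
]


def make_example_config() -> dict:
    """Hypothetical example config: key names follow the README,
    values are placeholders, not repository defaults."""
    return {
        "dataset_params": {
            "dataset": "ISIC2019",  # or "APTOS2019", "DDSM"
            "data_location": "/path/to/datasets",
            "train_transforms": {},  # augmentation specs omitted
            "val_transforms": {},
            "test_transforms": {},
        },
        # Assumed key names for the dataloader settings.
        "dataloader_params": {"batch_size": 64, "num_workers": 4},
        "model_params": {
            "backbone_type": "deit_small",
            "transformers_params": {
                "img_size": 224,
                "patch_size": 16,
                "pretrained_type": "supervised",  # or "dino"
            },
            "pretrained": True,
            "freeze_backbone": False,
        },
        "optimization_params": {
            "optimizer": {
                "type": "Adam",  # hypothetical optimizer name
                "autoscale_rl": True,
                "params": {"lr": 1e-4, "weight_decay": 1e-5},  # assumed keys
            },
            "LARS_params": {"use": False, "batch_act_thresh": 256},
            "scheduler": {"type": ["CosineAnnealingLR"], "params": {}},  # hypothetical
        },
        "training_params": {
            "model_name": "deit_small_isic2019",
            "val_every": 1.0,
            "log_every": 50,
            "save_best_model": True,
            "log_embeddings": False,
            "knn_eval": False,
            "grad_clipping": 0,
            "use_tensorboard": False,
            "use_mixed_precision": True,
            "save_dir": "checkpoints",
        },
        "system_params": {},  # GPU keys not documented in the README
        "log_params": {},  # logger keys not documented in the README
        "lr_finder": {
            "grid_search_params": {
                "min_pow": -5,
                "max_pow": -2,  # spelling assumed from the README description
                "resolution": 10,
                "n_epochs": 5,
                "random_lr": False,
                "keep_schedule": True,
                "report_intermediate_steps": False,
            }
        },
        "transfer_learning_params": {
            "use_pretrained": False,
            "pretrained_model_name": "",
            "pretrained_path": "",
        },
    }


def missing_sections(cfg: dict) -> list:
    # README-listed top-level sections that the config lacks, in README order.
    return [name for name in REQUIRED_SECTIONS if name not in cfg]


def num_patches(cfg: dict) -> int:
    # Non-overlapping patch_size x patch_size patches over an img_size x img_size input.
    tp = cfg["model_params"]["transformers_params"]
    return (tp["img_size"] // tp["patch_size"]) ** 2
```

Calling json.dumps(make_example_config(), indent=2) produces text that could be saved as the json file and then edited by hand.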
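The autoscale_rl and LARS_params options both depend on the batch size. A minimal sketch of that logic follows, assuming the common linear scaling rule (learning rate proportional to batch size, relative to a reference batch of 256). The repository may use a different rule or reference batch size, and effective_lr and use_lars are hypothetical helper names.

```python
def effective_lr(base_lr: float, batch_size: int, autoscale: bool,
                 reference_batch: int = 256) -> float:
    # Assumed linear scaling rule: doubling the batch size doubles the lr.
    if not autoscale:
        return base_lr
    return base_lr * batch_size / reference_batch


def use_lars(lars_params: dict, batch_size: int) -> bool:
    # As described in the README: LARS only when use=True and
    # batch_size >= batch_act_thresh; otherwise the default optimizer is kept.
    return bool(lars_params.get("use", False)) and batch_size >= lars_params["batch_act_thresh"]
```

With a base learning rate of 1e-3, a batch size of 512 gives an effective learning rate of 2e-3 under this rule, and LARS would only be switched on for batches at or above the configured threshold.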
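The grid_search_params of lr_finder describe a search over learning rates between 10**min_pow and 10**max_pow. The following sketch shows how such candidates could be generated, assuming a log-spaced grid of resolution points, or log-uniform sampling when random_lr is True. candidate_lrs and its seed argument are hypothetical, and training and scoring each candidate for n_epochs is left out.

```python
import random


def candidate_lrs(min_pow: float, max_pow: float, resolution: int,
                  random_lr: bool = False, seed: int = 0) -> list:
    """Hypothetical lr_finder grid: `resolution` learning rates
    between 10**min_pow and 10**max_pow."""
    if resolution < 1 or min_pow > max_pow:
        raise ValueError("need resolution >= 1 and min_pow <= max_pow")
    if random_lr:
        # Sampling the exponent uniformly makes the rates log-uniform in range.
        rng = random.Random(seed)
        return sorted(10 ** rng.uniform(min_pow, max_pow) for _ in range(resolution))
    if resolution == 1:
        return [10 ** min_pow]
    step = (max_pow - min_pow) / (resolution - 1)
    # Evenly spaced exponents give a log-spaced grid, endpoints included.
    return [10 ** (min_pow + i * step) for i in range(resolution)]
```

For instance, candidate_lrs(-5, -1, 5) yields one learning rate per power of 10, from 1e-5 up to 1e-1. A log scale is used because useful learning rates typically span several orders of magnitude.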
Owner
Christos Matsoukas
PhD student in Deep Learning @ KTH Royal Institute of Technology
OpenFed: A Comprehensive and Versatile Open-Source Federated Learning Framework

OpenFed: A Comprehensive and Versatile Open-Source Federated Learning Framework Introduction OpenFed is a foundational library for federated learning

25 Dec 12, 2022
Anti-Adversarially Manipulated Attributions for Weakly and Semi-Supervised Semantic Segmentation (CVPR 2021)

Anti-Adversarially Manipulated Attributions for Weakly and Semi-Supervised Semantic Segmentation Input Image Initial CAM Successive Maps with adversar

Jungbeom Lee 110 Dec 07, 2022
Self-Supervised Pillar Motion Learning for Autonomous Driving (CVPR 2021)

Self-Supervised Pillar Motion Learning for Autonomous Driving Chenxu Luo, Xiaodong Yang, Alan Yuille Self-Supervised Pillar Motion Learning for Autono

QCraft 101 Dec 05, 2022
Single Image Super-Resolution (SISR) with SRResNet, EDSR and SRGAN

Single Image Super-Resolution (SISR) with SRResNet, EDSR and SRGAN Introduction Image super-resolution (SR) is the process of recovering high-resoluti

8 Apr 15, 2022
Codes for CIKM'21 paper 'Self-Supervised Graph Co-Training for Session-based Recommendation'.

COTREC Codes for CIKM'21 paper 'Self-Supervised Graph Co-Training for Session-based Recommendation'. Requirements: Python 3.7, Pytorch 1.6.0 Best Hype

Xin Xia 42 Dec 09, 2022
Image Recognition using Pytorch

PyTorch Project Template A simple and well designed structure is essential for any Deep Learning project, so after a lot practice and contributing in

Sarat Chinni 1 Nov 02, 2021
Image segmentation with private İstanbul Dataset

Image Segmentation This repo was created for academic research and test result. Repo will update after academic article online. This repo contains wei

İrem KÖMÜRCÜ 9 Dec 11, 2022
FaceAnon - Anonymize people in images and videos using yolov5-crowdhuman

Face Anonymizer Blur faces from image and video files in /input/ folder. Require

22 Nov 03, 2022
Implementation of the CenterNet (Objects as Points) object detection model in PyTorch

Bubbliiiing 267 Dec 29, 2022
Implementation for NeurIPS 2021 Submission: SparseFed

READ THIS FIRST This repo is an anonymized version of an existing repository of GitHub, for the AIStats 2021 submission: SparseFed: Mitigating Model P

2 Jun 15, 2022
Machine learning framework for both deep learning and traditional algorithms

NeoML is an end-to-end machine learning framework that allows you to build, train, and deploy ML models. This framework is used by ABBYY engineers for

NeoML 704 Dec 27, 2022
Rendering color and depth images for ShapeNet models.

Color & Depth Renderer for ShapeNet This library includes the tools for rendering multi-view color and depth images of ShapeNet models. Physically bas

Yinyu Nie 41 Dec 19, 2022
Random Erasing Data Augmentation. Experiments on CIFAR10, CIFAR100 and Fashion-MNIST

Random Erasing Data Augmentation. This code has the source code for

Zhun Zhong 654 Dec 26, 2022
(CVPR 2022 - oral) Multi-View Depth Estimation by Fusing Single-View Depth Probability with Multi-View Geometry

Multi-View Depth Estimation by Fusing Single-View Depth Probability with Multi-View Geometry Official implementation of the paper Multi-View Depth Est

Bae, Gwangbin 138 Dec 28, 2022
Python SDK for building, training, and deploying ML models

Overview of Kubeflow Fairing Kubeflow Fairing is a Python package that streamlines the process of building, training, and deploying machine learning (

Kubeflow 325 Dec 13, 2022
Code for CMaskTrack R-CNN (proposed in Occluded Video Instance Segmentation)

CMaskTrack R-CNN for OVIS This repo serves as the official code release of the CMaskTrack R-CNN model on the Occluded Video Instance Segmentation data

Q . J . Y 61 Nov 25, 2022
A Python Reconnection Tool for alt:V

altv-reconnect What? It invokes a reconnect in the altV Client Dev Console. You get to determine when your local client should reconnect when developi

8 Jun 30, 2022
Python library for loading and using triangular meshes.

Trimesh is a pure Python (2.7-3.4+) library for loading and using triangular meshes with an emphasis on watertight surfaces. The goal of the library i

Michael Dawson-Haggerty 2.2k Jan 07, 2023
CoReNet is a technique for joint multi-object 3D reconstruction from a single RGB image.

CoReNet CoReNet is a technique for joint multi-object 3D reconstruction from a single RGB image. It produces coherent reconstructions, where all objec

Google Research 80 Dec 25, 2022
Rayvens makes it possible for data scientists to access hundreds of data services within Ray with little effort.

Rayvens augments Ray with events. With Rayvens, Ray applications can subscribe to event streams, process and produce events. Rayvens leverages Apache

CodeFlare 32 Dec 25, 2022