[CVPR 2022] Official code for the paper: "A Stitch in Time Saves Nine: A Train-Time Regularizing Loss for Improved Neural Network Calibration"

Overview

MDCA Calibration

This is the official PyTorch implementation for the paper: "A Stitch in Time Saves Nine: A Train-Time Regularizing Loss for Improved Neural Network Calibration".

Abstract

Deep Neural Networks (DNNs) make overconfident mistakes which can prove to be probematic in deployment in safety critical applications. Calibration is aimed to enhance trust in DNNs. The goal of our proposed Multi-Class Difference in Confidence and Accuracy (MDCA) loss is to align the probability estimates in accordance with accuracy thereby enhancing the trust in DNN decisions. MDCA can be used in case of image classification, image segmentation, and natural language classification tasks.

Teaser

Above image shows comparison of classwise reliability diagrams of Cross-Entropy vs. our proposed method.

Requirements

  • Python 3.8
  • PyTorch 1.8

Directly install using pip

 pip install -r requirements.txt

Training scripts:

Refer to the scripts folder to train for every model and dataset. Overall the command to train looks like below where each argument can be changed accordingly on how to train. Also refer to dataset/__init__.py and models/__init__.py for correct arguments to train with. Argument parser can be found in utils/argparser.py.

Train with cross-entropy:

python train.py --dataset cifar10 --model resnet56 --schedule-steps 80 120 --epochs 160 --loss cross_entropy 

Train with FL+MDCA: Also mention the gamma (for Focal Loss) and beta (Weight assigned to MDCA) to train FL+MDCA with

python train.py --dataset cifar10 --model resnet56 --schedule-steps 80 120 --epochs 160 --loss FL+MDCA --gamma 1.0 --beta 1.0 

Train with NLL+MDCA:

python train.py --dataset cifar10 --model resnet56 --schedule-steps 80 120 --epochs 160 --loss NLL+MDCA --beta 1.0

Post Hoc Calibration:

To do post-hoc calibration, we can use the following command.

lr and patience value is used for Dirichlet calibration. To change range of grid-search in dirichlet calibration, refer to posthoc_calibrate.py.

python posthoc_calibrate.py --dataset cifar10 --model resnet56 --lr 0.001 --patience 5 --checkpoint path/to/your/trained/model

Other Experiments (Dataset Drift, Dataset Imbalance):

experiments folder contains our experiments on PACS, Rotated MNIST and Imbalanced CIFAR10. Please refer to the scripts provided to run them.

Citation

If you find our work useful in your research, please cite the following:

@InProceedings{StitchInTime,
    author    = {R. Hebbalaguppe, J. Prakash, N. Madan, C. Arora},
    title     = {A Stitch in Time Saves Nine: A Train-Time Regularizing Loss for Improved Neural Network Calibration},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2022}
}

Contact

For questions about our paper or code, please contact any of the authors (@neelabh17, @bicycleman15, @rhebbalaguppe ) or raise an issue on GitHub.

References:

The code is adapted from the following repositories:

[1] bearpaw/pytorch-classification [2] torrvision/focal_calibration [3] Jonathan-Pearce/calibration_library

Owner
MDCA Calibration
MDCA Calibration
Code for 1st place solution in Sleep AI Challenge SNU Hospital

Sleep AI Challenge SNU Hospital 2021 Code for 1st place solution for Sleep AI Challenge (Note that the code is not fully organized) Refer to the notio

Saewon Yang 13 Jan 03, 2022
Additional environments compatible with OpenAI gym

Decentralized Control of Quadrotor Swarms with End-to-end Deep Reinforcement Learning A codebase for training reinforcement learning policies for quad

Zhehui Huang 40 Dec 06, 2022
A gesture recognition system powered by OpenPose, k-nearest neighbours, and local outlier factor.

OpenHands OpenHands is a gesture recognition system powered by OpenPose, k-nearest neighbours, and local outlier factor. Currently the system can iden

Paul Treanor 12 Jan 10, 2022
The official implementation of ELSA: Enhanced Local Self-Attention for Vision Transformer

ELSA: Enhanced Local Self-Attention for Vision Transformer By Jingkai Zhou, Pich

DamoCV 87 Dec 19, 2022
A Python-based development platform for automated trading systems - from backtesting to optimisation to livetrading.

AutoTrader AutoTrader is Python-based platform intended to help in the development, optimisation and deployment of automated trading systems. From sim

Kieran Mackle 485 Jan 09, 2023
Robbing the FED: Directly Obtaining Private Data in Federated Learning with Modified Models

Robbing the FED: Directly Obtaining Private Data in Federated Learning with Modified Models This repo contains a barebones implementation for the atta

16 Dec 04, 2022
[제 13회 투빅스 컨퍼런스] OK Mugle! - 장르부터 멜로디까지, Content-based Music Recommendation

Ok Mugle! 🎵 장르부터 멜로디까지, Content-based Music Recommendation 'Ok Mugle!'은 제13회 투빅스 컨퍼런스(2022.01.15)에서 진행한 음악 추천 프로젝트입니다. Description 📖 본 프로젝트에서는 Kakao

SeongBeomLEE 5 Oct 09, 2022
Minimal PyTorch implementation of Generative Latent Optimization from the paper "Optimizing the Latent Space of Generative Networks"

Minimal PyTorch implementation of Generative Latent Optimization This is a reimplementation of the paper Piotr Bojanowski, Armand Joulin, David Lopez-

Thomas Neumann 117 Nov 27, 2022
This repo tries to recognize faces in the dataset you created

YÜZ TANIMA SİSTEMİ Bu repo oluşturacağınız yüz verisetlerini tanımaya çalışan ma

Mehdi KOŞACA 2 Dec 30, 2021
Dataloader tools for language modelling

Installation: pip install lm_dataloader Design Philosophy A library to unify lm dataloading at large scale Simple interface, any tokenizer can be inte

5 Mar 25, 2022
BOVText: A Large-Scale, Multidimensional Multilingual Dataset for Video Text Spotting

BOVText: A Large-Scale, Bilingual Open World Dataset for Video Text Spotting Updated on December 10, 2021 (Release all dataset(2021 videos)) Updated o

weijiawu 47 Dec 26, 2022
UV matrix decompostion using movielens dataset

UV-matrix-decompostion-with-kfold UV matrix decompostion using movielens dataset upload the 'ratings.dat' file install the following python libraries

2 Oct 18, 2022
PyKale is a PyTorch library for multimodal learning and transfer learning as well as deep learning and dimensionality reduction on graphs, images, texts, and videos

PyKale is a PyTorch library for multimodal learning and transfer learning as well as deep learning and dimensionality reduction on graphs, images, texts, and videos. By adopting a unified pipeline-ba

PyKale 370 Dec 27, 2022
TransCD: Scene Change Detection via Transformer-based Architecture

TransCD: Scene Change Detection via Transformer-based Architecture

wangzhixue 29 Dec 11, 2022
Cache Requests in Deta Bases and Echo them with Deta Micros

Deta Echo Cache Leverage the awesome Deta Micros and Deta Base to cache requests and echo them as needed. Stop worrying about slow public APIs or agre

Gingerbreadfork 8 Dec 07, 2021
Hamiltonian Dynamics with Non-Newtonian Momentum for Rapid Sampling

Hamiltonian Dynamics with Non-Newtonian Momentum for Rapid Sampling Code for the paper: Greg Ver Steeg and Aram Galstyan. "Hamiltonian Dynamics with N

Greg Ver Steeg 25 Mar 14, 2022
DLL: Direct Lidar Localization

DLL: Direct Lidar Localization Summary This package presents DLL, a direct map-based localization technique using 3D LIDAR for its application to aeri

Service Robotics Lab 127 Dec 16, 2022
Unified Pre-training for Self-Supervised Learning and Supervised Learning for ASR

UniSpeech The family of UniSpeech: UniSpeech (ICML 2021): Unified Pre-training for Self-Supervised Learning and Supervised Learning for ASR UniSpeech-

Microsoft 282 Jan 09, 2023
ALBERT-pytorch-implementation - ALBERT pytorch implementation

ALBERT-pytorch-implementation developing... 모델의 개념이해를 돕기 위한 구현물로 현재 변수명을 상세히 적었고

BG Kim 3 Oct 06, 2022
PyTorch implementation of Off-policy Learning in Two-stage Recommender Systems

Off-Policy-2-Stage This repo provides a PyTorch implementation of the MovieLens experiments for the following paper: Off-policy Learning in Two-stage

Jiaqi Ma 25 Dec 12, 2022