[CVPR 2022] Official code for the paper: "A Stitch in Time Saves Nine: A Train-Time Regularizing Loss for Improved Neural Network Calibration"

Overview

MDCA Calibration

This is the official PyTorch implementation for the paper: "A Stitch in Time Saves Nine: A Train-Time Regularizing Loss for Improved Neural Network Calibration".

Abstract

Deep Neural Networks (DNNs) make overconfident mistakes which can prove to be probematic in deployment in safety critical applications. Calibration is aimed to enhance trust in DNNs. The goal of our proposed Multi-Class Difference in Confidence and Accuracy (MDCA) loss is to align the probability estimates in accordance with accuracy thereby enhancing the trust in DNN decisions. MDCA can be used in case of image classification, image segmentation, and natural language classification tasks.

Teaser

Above image shows comparison of classwise reliability diagrams of Cross-Entropy vs. our proposed method.

Requirements

  • Python 3.8
  • PyTorch 1.8

Directly install using pip

 pip install -r requirements.txt

Training scripts:

Refer to the scripts folder to train for every model and dataset. Overall the command to train looks like below where each argument can be changed accordingly on how to train. Also refer to dataset/__init__.py and models/__init__.py for correct arguments to train with. Argument parser can be found in utils/argparser.py.

Train with cross-entropy:

python train.py --dataset cifar10 --model resnet56 --schedule-steps 80 120 --epochs 160 --loss cross_entropy 

Train with FL+MDCA: Also mention the gamma (for Focal Loss) and beta (Weight assigned to MDCA) to train FL+MDCA with

python train.py --dataset cifar10 --model resnet56 --schedule-steps 80 120 --epochs 160 --loss FL+MDCA --gamma 1.0 --beta 1.0 

Train with NLL+MDCA:

python train.py --dataset cifar10 --model resnet56 --schedule-steps 80 120 --epochs 160 --loss NLL+MDCA --beta 1.0

Post Hoc Calibration:

To do post-hoc calibration, we can use the following command.

lr and patience value is used for Dirichlet calibration. To change range of grid-search in dirichlet calibration, refer to posthoc_calibrate.py.

python posthoc_calibrate.py --dataset cifar10 --model resnet56 --lr 0.001 --patience 5 --checkpoint path/to/your/trained/model

Other Experiments (Dataset Drift, Dataset Imbalance):

experiments folder contains our experiments on PACS, Rotated MNIST and Imbalanced CIFAR10. Please refer to the scripts provided to run them.

Citation

If you find our work useful in your research, please cite the following:

@InProceedings{StitchInTime,
    author    = {R. Hebbalaguppe, J. Prakash, N. Madan, C. Arora},
    title     = {A Stitch in Time Saves Nine: A Train-Time Regularizing Loss for Improved Neural Network Calibration},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2022}
}

Contact

For questions about our paper or code, please contact any of the authors (@neelabh17, @bicycleman15, @rhebbalaguppe ) or raise an issue on GitHub.

References:

The code is adapted from the following repositories:

[1] bearpaw/pytorch-classification [2] torrvision/focal_calibration [3] Jonathan-Pearce/calibration_library

Owner
MDCA Calibration
MDCA Calibration
Pretrained language model and its related optimization techniques developed by Huawei Noah's Ark Lab.

Pretrained Language Model This repository provides the latest pretrained language models and its related optimization techniques developed by Huawei N

HUAWEI Noah's Ark Lab 2.6k Jan 01, 2023
ONNX-GLPDepth - Python scripts for performing monocular depth estimation using the GLPDepth model in ONNX

ONNX-GLPDepth - Python scripts for performing monocular depth estimation using the GLPDepth model in ONNX

Ibai Gorordo 18 Nov 06, 2022
Implementation of Learning Gradient Fields for Molecular Conformation Generation (ICML 2021).

[PDF] | [Slides] The official implementation of Learning Gradient Fields for Molecular Conformation Generation (ICML 2021 Long talk) Installation Inst

MilaGraph 117 Dec 09, 2022
Skipgram Negative Sampling in PyTorch

PyTorch SGNS Word2Vec's SkipGramNegativeSampling in Python. Yet another but quite general negative sampling loss implemented in PyTorch. It can be use

Jamie J. Seol 287 Dec 14, 2022
Nicholas Lee 3 Jan 09, 2022
Distance-Ratio-Based Formulation for Metric Learning

Distance-Ratio-Based Formulation for Metric Learning Environment Python3 Pytorch (http://pytorch.org/) (version 1.6.0+cu101) json tqdm Preparing datas

Hyeongji Kim 1 Dec 07, 2022
BrainGNN - A deep learning model for data-driven discovery of functional connectivity

A deep learning model for data-driven discovery of functional connectivity https://doi.org/10.3390/a14030075 Usman Mahmood, Zengin Fu, Vince D. Calhou

Usman Mahmood 3 Aug 28, 2022
Towards Multi-Camera 3D Human Pose Estimation in Wild Environment

PanopticStudio Toolbox This repository has a toolbox to download, process, and visualize the Panoptic Studio (Panoptic) data. Note: Sep-21-2020: Curre

335 Jan 09, 2023
GE2340 project source code without credentials.

GE2340-Project-Public GE2340 project source code without credentials. Run the bot.py to start the bot Telegram: @jasperwong_ge2340_bot If the bot does

0 Feb 10, 2022
Code and data for "Broaden the Vision: Geo-Diverse Visual Commonsense Reasoning" (EMNLP 2021).

GD-VCR Code for Broaden the Vision: Geo-Diverse Visual Commonsense Reasoning (EMNLP 2021). Research Questions and Aims: How well can a model perform o

Da Yin 24 Oct 13, 2022
Code for the published paper : Learning to recognize rare traffic sign

Improving traffic sign recognition by active search This repo contains code for the paper : "Learning to recognise rare traffic signs" How to use this

samsja 4 Jan 05, 2023
Using Machine Learning to Test Causal Hypotheses in Conjoint Analysis

Readme File for "Using Machine Learning to Test Causal Hypotheses in Conjoint Analysis" by Ham, Imai, and Janson. (2022) All scripts were written and

0 Jan 27, 2022
Evolving neural network parameters in JAX.

Evolving Neural Networks in JAX This repository holds code displaying techniques for applying evolutionary network training strategies in JAX. Each sc

Trevor Thackston 6 Feb 12, 2022
Official implementation of NLOS-OT: Passive Non-Line-of-Sight Imaging Using Optimal Transport (IEEE TIP, accepted)

NLOS-OT Official implementation of NLOS-OT: Passive Non-Line-of-Sight Imaging Using Optimal Transport (IEEE TIP, accepted) Description In this reposit

Ruixu Geng(耿瑞旭) 16 Dec 16, 2022
Official Code for AdvRush: Searching for Adversarially Robust Neural Architectures (ICCV '21)

AdvRush Official Code for AdvRush: Searching for Adversarially Robust Neural Architectures (ICCV '21) Environmental Set-up Python == 3.6.12, PyTorch =

11 Dec 10, 2022
Wordle Env: A Daily Word Environment for Reinforcement Learning

Wordle Env: A Daily Word Environment for Reinforcement Learning Setup Steps: git pull [email&#

2 Mar 28, 2022
The 2nd place solution of 2021 google landmark retrieval on kaggle.

Google_Landmark_Retrieval_2021_2nd_Place_Solution The 2nd place solution of 2021 google landmark retrieval on kaggle. Environment We use cuda 11.1/pyt

229 Dec 13, 2022
[ICML 2022] The official implementation of Graph Stochastic Attention (GSAT).

Graph Stochastic Attention (GSAT) The official implementation of GSAT for our paper: Interpretable and Generalizable Graph Learning via Stochastic Att

85 Nov 27, 2022
Codes for Causal Semantic Generative model (CSG), the model proposed in "Learning Causal Semantic Representation for Out-of-Distribution Prediction" (NeurIPS-21)

Learning Causal Semantic Representation for Out-of-Distribution Prediction This repository is the official implementation of "Learning Causal Semantic

Chang Liu 54 Dec 01, 2022
Time Series Cross-Validation -- an extension for scikit-learn

TSCV: Time Series Cross-Validation This repository is a scikit-learn extension for time series cross-validation. It introduces gaps between the traini

Wenjie Zheng 222 Jan 01, 2023