Learning Confidence for Out-of-Distribution Detection in Neural Networks

Overview

Learning Confidence Estimates for Neural Networks

This repository contains the code for the paper Learning Confidence for Out-of-Distribution Detection in Neural Networks. In this work, we demonstrate how to augment neural networks with a confidence estimation branch, which can be used to identify misclassified and out-of-distribution examples.

To learn confidence estimates during training, we provide the neural network with "hints" towards the correct output whenever it exhibits low confidence in its predictions. Hints are provided by pushing the prediction closer to the target distribution via interpolation, where the amount of interpolation proportional to the network's confidence that its prediction is correct. To discourage the network from always asking for free hints, a small penalty is applied whenever it is not confident. As a result, the network learns to only produce low confidence estimates when it is likely to make an incorrect prediction.

Bibtex:

@article{devries2018learning,
  title={Learning Confidence for Out-of-Distribution Detection in Neural Networks},
  author={DeVries, Terrance and Taylor, Graham W},
  journal={arXiv preprint arXiv:1802.04865},
  year={2018}
}

Results and Usage

We evalute our method on the task of out-of-distribution detection using three different neural network architectures: DenseNet, WideResNet, and VGG. CIFAR-10 and SVHN are used as the in-distribution datasets, while TinyImageNet, LSUN, iSUN, uniform noise, and Gaussian noise are used as the out-of-distribution datasets. Definitions of evaluation metrics can be found in the paper.

Dependencies

PyTorch v0.3.0
tqdm
visdom
seaborn
Pillow
scikit-learn

Training

Train a model with a confidence estimator with train.py. During training you can use visdom to see a histogram of confidence estimates from the test set. Training logs will be stored in the logs/ folder, while checkpoints are stored in the checkpoints/ folder.

Args Options Description
dataset cifar10,
svhn
Selects which dataset to train on.
model densenet,
wideresnet,
vgg13
Selects which model architecture to use.
batch_size [int] Number of samples per batch.
epochs [int] Number of epochs for training.
seed [int] Random seed.
learning_rate [float] Learning rate.
data_augmentation Train with standard data augmentation (random flipping and translation).
cutout [int] Indicates the patch size to use for Cutout. If 0, Cutout is not used.
budget [float] Controls how often the network can choose have low confidence in its prediction. Increasing the budget will bias the output towards low confidence predictions, while decreasing the budget will produce more high confidence predictions.
baseline Train the model without the confidence branch.

Use the following settings to replicate the experiments from the paper:

VGG13 on CIFAR-10

python train.py --dataset cifar10 --model vgg13 --budget 0.3 --data_augmentation --cutout 16

WideResNet on CIFAR-10

python train.py --dataset cifar10 --model wideresnet --budget 0.3 --data_augmentation --cutout 16

DenseNet on CIFAR-10

python train.py --dataset cifar10 --model densenet --budget 0.3 --epochs 300 --batch_size 64 --data_augmentation --cutout 16

VGG13 on SVHN

python train.py --dataset svhn --model vgg13 --budget 0.3 --learning_rate 0.01 --epochs 160 --data_augmentation --cutout 20

WideResNet on SVHN

python train.py --dataset svhn --model wideresnet --budget 0.3 --learning_rate 0.01 --epochs 160 --data_augmentation --cutout 20

DenseNet on SVHN

python train.py --dataset svhn --model densenet --budget 0.3 --learning_rate 0.01 --epochs 300 --batch_size 64  --data_augmentation --cutout 20

Out-of-distribution detection

Evaluate a trained model with out_of_distribution_detection.py. Before running this you will need to download the out-of-distribution datasets from Shiyu Liang's ODIN github repo and modify the data paths in the file according to where you saved the datasets.

Args Options Description
ind_dataset cifar10,
svhn
Indicates which dataset to use as in-distribution. Should be the same one that the model was trained on.
ood_dataset tinyImageNet_crop,
tinyImageNet_resize,
LSUN_crop,
LSUN_resize,
iSUN,
Uniform,
Gaussian,
all
Indicates which dataset to use as the out-of-distribution datset.
model densenet,
wideresnet,
vgg13
Selects which model architecture to use. Should be the same one that the model was trained on.
process baseline,
ODIN,
confidence,
confidence_scaling
Indicates which method to use for out-of-distribution detection. Baseline uses the maximum softmax probability. ODIN applies temperature scaling and input pre-processing to the baseline method. Confidence uses the learned confidence estimates. Confidence scaling applies input pre-processing to the confidence estimates.
batch_size [int] Number of samples per batch.
T [float] Temperature to use for temperature scaling.
epsilon [float] Noise magnitude to use for input pre-processing.
checkpoint [str] Filename of trained model checkpoint. Assumes the file is in the checkpoints/ folder. A .pt extension is also automatically added to the filename.
validation Use this flag for fine-tuning T and epsilon. If flag is on, the script will only evaluate on the first 1000 samples in the out-of-distribution dataset. If flag is not used, the remaining samples are used for evaluation. Based on validation procedure from ODIN.

Example commands for running the out-of-distribution detection script:

Baseline

python out_of_distribution_detection.py --ind_dataset svhn --ood_dataset all --model vgg13 --process baseline --checkpoint svhn_vgg13_budget_0.0_seed_0

ODIN

python out_of_distribution_detection.py --ind_dataset cifar10 --ood_dataset tinyImageNet_resize --model densenet --process ODIN --T 1000 --epsilon 0.001 --checkpoint cifar10_densenet_budget_0.0_seed_0

Confidence

python out_of_distribution_detection.py --ind_dataset cifar10 --ood_dataset LSUN_crop --model vgg13 --process confidence --checkpoint cifar10_vgg13_budget_0.3_seed_0

Confidence scaling

python out_of_distribution_detection.py --ind_dataset svhn --ood_dataset iSUN --model wideresnet --process confidence_scaling --epsilon 0.001 --checkpoint svhn_wideresnet_budget_0.3_seed_0
A PyTorch implementation of EfficientNet and EfficientNetV2 (coming soon!)

EfficientNet PyTorch Quickstart Install with pip install efficientnet_pytorch and load a pretrained EfficientNet with: from efficientnet_pytorch impor

Luke Melas-Kyriazi 7.2k Jan 06, 2023
City-seeds - A random generator of cultural characteristics intended to spark ideas and help draw threads

City Seeds This is a random generator of cultural characteristics intended to sp

Aydin O'Leary 2 Mar 12, 2022
Dynamic Neural Representational Decoders for High-Resolution Semantic Segmentation

Dynamic Neural Representational Decoders for High-Resolution Semantic Segmentation Requirements This repository needs mmsegmentation Training To train

20 May 28, 2022
SwinIR: Image Restoration Using Swin Transformer

SwinIR: Image Restoration Using Swin Transformer This repository is the official PyTorch implementation of SwinIR: Image Restoration Using Shifted Win

Jingyun Liang 2.4k Jan 05, 2023
Learning to Map Large-scale Sparse Graphs on Memristive Crossbar

Release of AutoGMap:Learning to Map Large-scale Sparse Graphs on Memristive Crossbar For reproduction of our searched model, the Ubuntu OS is recommen

2 Aug 23, 2022
Learning to Disambiguate Strongly Interacting Hands via Probabilistic Per-Pixel Part Segmentation [3DV 2021 Oral]

Learning to Disambiguate Strongly Interacting Hands via Probabilistic Per-Pixel Part Segmentation [3DV 2021 Oral] Learning to Disambiguate Strongly In

Zicong Fan 40 Dec 22, 2022
An implementation of Fastformer: Additive Attention Can Be All You Need in TensorFlow

Fast Transformer This repo implements Fastformer: Additive Attention Can Be All You Need by Wu et al. in TensorFlow. Fast Transformer is a Transformer

Rishit Dagli 139 Dec 28, 2022
UCSD Oasis platform

oasis UCSD Oasis platform Local project setup Install Docker Compose and make sure you have Pip installed Clone the project and go to the project fold

InSTEDD 4 Jun 16, 2021
Most popular metrics used to evaluate object detection algorithms.

Most popular metrics used to evaluate object detection algorithms.

Rafael Padilla 4.4k Dec 25, 2022
Code for the paper "Unsupervised Contrastive Learning of Sound Event Representations", ICASSP 2021.

Unsupervised Contrastive Learning of Sound Event Representations This repository contains the code for the following paper. If you use this code or pa

Eduardo Fonseca 81 Dec 22, 2022
(CVPR 2022) A minimalistic mapless end-to-end stack for joint perception, prediction, planning and control for self driving.

LAV Learning from All Vehicles Dian Chen, Philipp Krähenbühl CVPR 2022 (also arXiV 2203.11934) This repo contains code for paper Learning from all veh

Dian Chen 300 Dec 15, 2022
Code for TIP 2017 paper --- Illumination Decomposition for Photograph with Multiple Light Sources.

Illumination_Decomposition Code for TIP 2017 paper --- Illumination Decomposition for Photograph with Multiple Light Sources. This code implements the

QAY 7 Nov 15, 2020
PFFDTD is an open-source FDTD simulator for 3D room acoustics

PFFDTD is an open-source FDTD simulator for 3D room acoustics

Brian Hamilton 34 Nov 24, 2022
Code for NeurIPS 2021 paper 'Spatio-Temporal Variational Gaussian Processes'

Spatio-Temporal Variational GPs This repository is the official implementation of the methods in the publication: O. Hamelijnck, W.J. Wilkinson, N.A.

AaltoML 26 Sep 16, 2022
A knowledge base construction engine for richly formatted data

Fonduer is a Python package and framework for building knowledge base construction (KBC) applications from richly formatted data. Note that Fonduer is

HazyResearch 386 Dec 05, 2022
Source code of CIKM2021 Long Paper "PSSL: Self-supervised Learning for Personalized Search with Contrastive Sampling".

PSSL Source code of CIKM2021 Long Paper "PSSL: Self-supervised Learning for Personalized Search with Contrastive Sampling". It consists of the pre-tra

2 Dec 21, 2021
Tensorboard for pytorch (and chainer, mxnet, numpy, ...)

tensorboardX Write TensorBoard events with simple function call. The current release (v2.3) is tested on anaconda3, with PyTorch 1.8.1 / torchvision 0

Tzu-Wei Huang 7.5k Dec 28, 2022
Apply AnimeGAN-v2 across frames of a video clip

title emoji colorFrom colorTo sdk app_file pinned AnimeGAN-v2 For Videos 🔥 blue red gradio app.py false AnimeGAN-v2 For Videos Apply AnimeGAN-v2 acro

Nathan Raw 36 Oct 18, 2022
RM Operation can equivalently convert ResNet to VGG, which is better for pruning; and can help RepVGG perform better when the depth is large.

RM Operation can equivalently convert ResNet to VGG, which is better for pruning; and can help RepVGG perform better when the depth is large.

184 Jan 04, 2023
NP DRAW paper released code

NP-DRAW: A Non-Parametric Structured Latent Variable Model for Image Generation This repo contains the official implementation for the NP-DRAW paper.

ZENG Xiaohui 22 Mar 13, 2022