Learning Confidence for Out-of-Distribution Detection in Neural Networks

Overview

Learning Confidence Estimates for Neural Networks

This repository contains the code for the paper Learning Confidence for Out-of-Distribution Detection in Neural Networks. In this work, we demonstrate how to augment neural networks with a confidence estimation branch, which can be used to identify misclassified and out-of-distribution examples.

To learn confidence estimates during training, we provide the neural network with "hints" towards the correct output whenever it exhibits low confidence in its predictions. Hints are provided by pushing the prediction closer to the target distribution via interpolation, where the amount of interpolation proportional to the network's confidence that its prediction is correct. To discourage the network from always asking for free hints, a small penalty is applied whenever it is not confident. As a result, the network learns to only produce low confidence estimates when it is likely to make an incorrect prediction.

Bibtex:

@article{devries2018learning,
  title={Learning Confidence for Out-of-Distribution Detection in Neural Networks},
  author={DeVries, Terrance and Taylor, Graham W},
  journal={arXiv preprint arXiv:1802.04865},
  year={2018}
}

Results and Usage

We evalute our method on the task of out-of-distribution detection using three different neural network architectures: DenseNet, WideResNet, and VGG. CIFAR-10 and SVHN are used as the in-distribution datasets, while TinyImageNet, LSUN, iSUN, uniform noise, and Gaussian noise are used as the out-of-distribution datasets. Definitions of evaluation metrics can be found in the paper.

Dependencies

PyTorch v0.3.0
tqdm
visdom
seaborn
Pillow
scikit-learn

Training

Train a model with a confidence estimator with train.py. During training you can use visdom to see a histogram of confidence estimates from the test set. Training logs will be stored in the logs/ folder, while checkpoints are stored in the checkpoints/ folder.

Args Options Description
dataset cifar10,
svhn
Selects which dataset to train on.
model densenet,
wideresnet,
vgg13
Selects which model architecture to use.
batch_size [int] Number of samples per batch.
epochs [int] Number of epochs for training.
seed [int] Random seed.
learning_rate [float] Learning rate.
data_augmentation Train with standard data augmentation (random flipping and translation).
cutout [int] Indicates the patch size to use for Cutout. If 0, Cutout is not used.
budget [float] Controls how often the network can choose have low confidence in its prediction. Increasing the budget will bias the output towards low confidence predictions, while decreasing the budget will produce more high confidence predictions.
baseline Train the model without the confidence branch.

Use the following settings to replicate the experiments from the paper:

VGG13 on CIFAR-10

python train.py --dataset cifar10 --model vgg13 --budget 0.3 --data_augmentation --cutout 16

WideResNet on CIFAR-10

python train.py --dataset cifar10 --model wideresnet --budget 0.3 --data_augmentation --cutout 16

DenseNet on CIFAR-10

python train.py --dataset cifar10 --model densenet --budget 0.3 --epochs 300 --batch_size 64 --data_augmentation --cutout 16

VGG13 on SVHN

python train.py --dataset svhn --model vgg13 --budget 0.3 --learning_rate 0.01 --epochs 160 --data_augmentation --cutout 20

WideResNet on SVHN

python train.py --dataset svhn --model wideresnet --budget 0.3 --learning_rate 0.01 --epochs 160 --data_augmentation --cutout 20

DenseNet on SVHN

python train.py --dataset svhn --model densenet --budget 0.3 --learning_rate 0.01 --epochs 300 --batch_size 64  --data_augmentation --cutout 20

Out-of-distribution detection

Evaluate a trained model with out_of_distribution_detection.py. Before running this you will need to download the out-of-distribution datasets from Shiyu Liang's ODIN github repo and modify the data paths in the file according to where you saved the datasets.

Args Options Description
ind_dataset cifar10,
svhn
Indicates which dataset to use as in-distribution. Should be the same one that the model was trained on.
ood_dataset tinyImageNet_crop,
tinyImageNet_resize,
LSUN_crop,
LSUN_resize,
iSUN,
Uniform,
Gaussian,
all
Indicates which dataset to use as the out-of-distribution datset.
model densenet,
wideresnet,
vgg13
Selects which model architecture to use. Should be the same one that the model was trained on.
process baseline,
ODIN,
confidence,
confidence_scaling
Indicates which method to use for out-of-distribution detection. Baseline uses the maximum softmax probability. ODIN applies temperature scaling and input pre-processing to the baseline method. Confidence uses the learned confidence estimates. Confidence scaling applies input pre-processing to the confidence estimates.
batch_size [int] Number of samples per batch.
T [float] Temperature to use for temperature scaling.
epsilon [float] Noise magnitude to use for input pre-processing.
checkpoint [str] Filename of trained model checkpoint. Assumes the file is in the checkpoints/ folder. A .pt extension is also automatically added to the filename.
validation Use this flag for fine-tuning T and epsilon. If flag is on, the script will only evaluate on the first 1000 samples in the out-of-distribution dataset. If flag is not used, the remaining samples are used for evaluation. Based on validation procedure from ODIN.

Example commands for running the out-of-distribution detection script:

Baseline

python out_of_distribution_detection.py --ind_dataset svhn --ood_dataset all --model vgg13 --process baseline --checkpoint svhn_vgg13_budget_0.0_seed_0

ODIN

python out_of_distribution_detection.py --ind_dataset cifar10 --ood_dataset tinyImageNet_resize --model densenet --process ODIN --T 1000 --epsilon 0.001 --checkpoint cifar10_densenet_budget_0.0_seed_0

Confidence

python out_of_distribution_detection.py --ind_dataset cifar10 --ood_dataset LSUN_crop --model vgg13 --process confidence --checkpoint cifar10_vgg13_budget_0.3_seed_0

Confidence scaling

python out_of_distribution_detection.py --ind_dataset svhn --ood_dataset iSUN --model wideresnet --process confidence_scaling --epsilon 0.001 --checkpoint svhn_wideresnet_budget_0.3_seed_0
A Pytorch implementation of SMU: SMOOTH ACTIVATION FUNCTION FOR DEEP NETWORKS USING SMOOTHING MAXIMUM TECHNIQUE

SMU_pytorch A Pytorch Implementation of SMU: SMOOTH ACTIVATION FUNCTION FOR DEEP NETWORKS USING SMOOTHING MAXIMUM TECHNIQUE arXiv https://arxiv.org/ab

Fuhang 36 Dec 24, 2022
Small repo describing how to use Hugging Face's Wav2Vec2 with PyCTCDecode

🤗 Transformers Wav2Vec2 + PyCTCDecode Introduction This repo shows how 🤗 Transformers can be used in combination with kensho-technologies's PyCTCDec

Patrick von Platen 102 Oct 22, 2022
Lightweight tool to perform MITM attack on local network

ARPSpy - A lightweight tool to perform MITM attack Using many library to perform ARP Spoof and auto-sniffing HTTP packet containing credential. (Never

MinhItachi 8 Aug 28, 2022
A series of Jupyter notebooks with Chinese comment that walk you through the fundamentals of Machine Learning and Deep Learning in python using Scikit-Learn and TensorFlow.

Hands-on-Machine-Learning 目的 这份笔记旨在帮助中文学习者以一种较快较系统的方式入门机器学习, 是在学习Hands-on Machine Learning with Scikit-Learn and TensorFlow这本书的 时候做的个人笔记: 此项目的可取之处 原书的

Baymax 1.5k Dec 21, 2022
Implementation of the paper "Self-Promoted Prototype Refinement for Few-Shot Class-Incremental Learning"

Self-Promoted Prototype Refinement for Few-Shot Class-Incremental Learning This is the implementation of the paper "Self-Promoted Prototype Refinement

Kai Zhu 78 Dec 02, 2022
A python script to convert images to animated sus among us crewmate twerk jifs as seen on r/196

img_sussifier A python script to convert images to animated sus among us crewmate twerk jifs as seen on r/196 Examples How to use install python pip i

41 Sep 30, 2022
ISNAS-DIP: Image Specific Neural Architecture Search for Deep Image Prior [CVPR 2022]

ISNAS-DIP: Image-Specific Neural Architecture Search for Deep Image Prior (CVPR 2022) Metin Ersin Arican*, Ozgur Kara*, Gustav Bredell, Ender Konukogl

Özgür Kara 24 Dec 18, 2022
This is the official implementation of "One Question Answering Model for Many Languages with Cross-lingual Dense Passage Retrieval".

CORA This is the official implementation of the following paper: Akari Asai, Xinyan Yu, Jungo Kasai and Hannaneh Hajishirzi. One Question Answering Mo

Akari Asai 59 Dec 28, 2022
PyTorch implementation for the Neuro-Symbolic Sudoku Solver leveraging the power of Neural Logic Machines (NLM)

Neuro-Symbolic Sudoku Solver PyTorch implementation for the Neuro-Symbolic Sudoku Solver leveraging the power of Neural Logic Machines (NLM). Please n

Ashutosh Hathidara 60 Dec 10, 2022
Simple torch.nn.module implementation of Alias-Free-GAN style filter and resample

Alias-Free-Torch Simple torch module implementation of Alias-Free GAN. This repository including Alias-Free GAN style lowpass sinc filter @filter.py A

이준혁(Junhyeok Lee) 64 Dec 22, 2022
VisualGPT: Data-efficient Adaptation of Pretrained Language Models for Image Captioning

VisualGPT Our Paper VisualGPT: Data-efficient Adaptation of Pretrained Language Models for Image Captioning Main Architecture of Our VisualGPT Downloa

Vision CAIR Research Group, KAUST 140 Dec 28, 2022
[ICML 2021] Break-It-Fix-It: Learning to Repair Programs from Unlabeled Data

Break-It-Fix-It: Learning to Repair Programs from Unlabeled Data This repo provides the source code & data of our paper: Break-It-Fix-It: Unsupervised

Michihiro Yasunaga 86 Nov 30, 2022
An implementation for the ICCV 2021 paper Deep Permutation Equivariant Structure from Motion.

Deep Permutation Equivariant Structure from Motion Paper | Poster This repository contains an implementation for the ICCV 2021 paper Deep Permutation

72 Dec 27, 2022
Pytorch Implementation of Spiking Neural Networks Calibration, ICML 2021

SNN_Calibration Pytorch Implementation of Spiking Neural Networks Calibration, ICML 2021 Feature Comparison of SNN calibration: Features SNN Direct Tr

Yuhang Li 60 Dec 27, 2022
Ground truth data for the Optical Character Recognition of Historical Classical Commentaries.

OCR Ground Truth for Historical Commentaries The dataset OCR ground truth for historical commentaries (GT4HistComment) was created from the public dom

Ajax Multi-Commentary 3 Sep 08, 2022
Semi-Supervised Learning, Object Detection, ICCV2021

End-to-End Semi-Supervised Object Detection with Soft Teacher By Mengde Xu*, Zheng Zhang*, Han Hu, Jianfeng Wang, Lijuan Wang, Fangyun Wei, Xiang Bai,

Microsoft 789 Dec 27, 2022
The official PyTorch code for NeurIPS 2021 ML4AD Paper, "Does Thermal data make the detection systems more reliable?"

MultiModal-Collaborative (MMC) Learning Framework for integrating RGB and Thermal spectral modalities This is the official code for NeurIPS 2021 Machi

NeurAI 12 Nov 02, 2022
A PyTorch Toolbox for Face Recognition

FaceX-Zoo FaceX-Zoo is a PyTorch toolbox for face recognition. It provides a training module with various supervisory heads and backbones towards stat

JDAI-CV 1.6k Jan 06, 2023
VD-BERT: A Unified Vision and Dialog Transformer with BERT

VD-BERT: A Unified Vision and Dialog Transformer with BERT PyTorch Code for the following paper at EMNLP2020: Title: VD-BERT: A Unified Vision and Dia

Salesforce 44 Nov 01, 2022
QTool: A Low-bit Quantization Toolbox for Deep Neural Networks in Computer Vision

This project provides abundant choices of quantization strategies (such as the quantization algorithms, training schedules and empirical tricks) for quantizing the deep neural networks into low-bit c

Monash Green AI Lab 51 Dec 10, 2022