Repository for "Improving evidential deep learning via multi-task learning," published in AAAI2022

Overview

Improving evidential deep learning via multi-task learning

This is the repository for the AAAI 2022 paper “Improving evidential deep learning via multi-task learning” by Dongpin Oh and Bonggun Shin.

This repository contains the code to reproduce the multi-task evidential neural network (MT-ENet). MT-ENet adds the Lipschitz MSE loss as an auxiliary loss to the evidential regression network (ENet). The Lipschitz MSE loss can improve the accuracy of the ENet while preserving its uncertainty-estimation capability. It does this by avoiding gradient conflict with the negative log-likelihood (NLL) loss, which is the original loss function of the ENet.


Setup

The packages required by this repository are listed in "requirements.txt". Install them with:

pip install -r requirements.txt

Example: training the ENet with the Lipschitz MSE loss

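from mtevi.mtevi import EvidentialMarginalLikelihood, EvidenceRegularizer, modified_mse
...
net = EvidentialNetwork()  # evidential regression network; outputs (gamma, nu, alpha, beta)
nll_loss = EvidentialMarginalLikelihood()  # original ENet loss (NLL loss)
reg = EvidenceRegularizer()  # evidential regularizer
mmse_loss = modified_mse  # Lipschitz MSE loss (the proposed auxiliary loss)
...
for inputs, labels in dataloader:
    # assumes an optimizer, e.g. torch.optim.Adam(net.parameters()), was created in the elided setup above
    optimizer.zero_grad()  # clear gradients left over from the previous step
    gamma, nu, alpha, beta = net(inputs)
    loss = nll_loss(gamma, nu, alpha, beta, labels)
    loss += reg(gamma, nu, alpha, beta, labels)
    loss += mmse_loss(gamma, nu, alpha, beta, labels)
    loss.backward()
    optimizer.step()  # apply the update computed from the combined loss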
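The NLL and regularizer terms used in the loop above can be written out as scalar functions for a single target. The sketch below is minimal and assumes the standard Normal-Inverse-Gamma parameterization (gamma, nu, alpha, beta) of deep evidential regression. The names `nig_nll` and `evidence_regularizer` are hypothetical. The repository's `EvidentialMarginalLikelihood` and `EvidenceRegularizer` may differ in reduction, weighting, or numerical details.

```python
import math


def nig_nll(y, gamma, nu, alpha, beta):
    """NLL of target y under the Student-t marginal of a
    Normal-Inverse-Gamma(gamma, nu, alpha, beta) output (assumed formulation)."""
    omega = 2.0 * beta * (1.0 + nu)
    return (
        0.5 * math.log(math.pi / nu)
        - alpha * math.log(omega)
        + (alpha + 0.5) * math.log(nu * (y - gamma) ** 2 + omega)
        + math.lgamma(alpha)
        - math.lgamma(alpha + 0.5)
    )


def evidence_regularizer(y, gamma, nu, alpha):
    """Penalize the total evidence (2*nu + alpha) in proportion to the absolute error."""
    return abs(y - gamma) * (2.0 * nu + alpha)
```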

Quick start

  • Synthetic data experiment.
python synthetic_exp.py
  • UCI regression benchmark experiments.
python uci_exp_norm.py -p energy
  • Drug-target affinity (DTA) regression task on the KIBA and Davis datasets.
python train_evinet.py -o test --type davis -f 0 --evi # ENet
python train_evinet.py -o test --type davis -f 0  # MT-ENet
  • Gradient conflict experiment on the DTA benchmarks.
python check_conflict.py --type davis -f 0 # Conflict between the Lipschitz MSE (proposed) and NLL loss. 
python check_conflict.py --type davis -f 0 --abl # Conflict between the simple MSE loss and NLL loss.
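Gradient conflict between two losses is commonly measured as the cosine similarity of their gradients with respect to the shared parameters. A negative value means that a step which decreases one loss tends to increase the other. The pure-Python sketch below implements that criterion on flattened gradient vectors. We assume, but have not confirmed, that check_conflict.py uses a comparable measure. The function names are hypothetical.

```python
import math


def cosine_similarity(g1, g2):
    """Cosine of the angle between two flattened gradient vectors."""
    dot = sum(a * b for a, b in zip(g1, g2))
    n1 = math.sqrt(sum(a * a for a in g1))
    n2 = math.sqrt(sum(b * b for b in g2))
    if n1 == 0.0 or n2 == 0.0:
        return 0.0  # a zero gradient neither agrees nor conflicts
    return dot / (n1 * n2)


def conflicts(g1, g2):
    """Treat two loss gradients as conflicting when they point in opposing directions."""
    return cosine_similarity(g1, g2) < 0.0
```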

Characteristics of the Lipschitz MSE loss


  • The Lipschitz MSE loss helps train the ENet to predict target values more accurately.
  • When the NLL loss (the original loss function of the ENet) increases the predictive uncertainty of the ENet, the Lipschitz MSE loss regularizes its own gradient to prevent gradient conflict with the NLL loss.
  • Please check our paper for details.
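The gradient-bounding idea can be illustrated with a Huber-style squared error. Its gradient with respect to the prediction gamma is clipped at magnitude 2u, so the loss is Lipschitz in gamma. This is a generic sketch, not the repository's `modified_mse`. In the training example, `modified_mse` also receives nu, alpha and beta, which suggests its bound depends on them. Here the threshold `u` is a hypothetical fixed input, and the function names are hypothetical too.

```python
def bounded_mse(y, gamma, u):
    """Squared error inside |y - gamma| <= u; continues linearly outside
    (Huber-style), so the loss is 2u-Lipschitz in gamma."""
    err = abs(y - gamma)
    if err <= u:
        return err ** 2
    return 2.0 * u * err - u ** 2  # matches err**2 at err == u


def bounded_mse_grad(y, gamma, u):
    """Derivative of bounded_mse with respect to gamma: 2*(gamma - y), clipped to [-2u, 2u]."""
    diff = gamma - y
    return 2.0 * max(-u, min(u, diff))
```

Inside the threshold, this behaves exactly like the MSE. Outside it, the gradient stops growing with the error, which limits how strongly this term can pull against another loss.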
Owner: deargen