Python package facilitating the use of Bayesian Deep Learning methods with Variational Inference for PyTorch

Overview

PyVarInf

PyVarInf provides facilities to easily train your PyTorch neural network models using variational inference.

Bayesian Deep Learning with Variational Inference

Bayesian Deep Learning

Assume we have a dataset D = {(x1, y1), ..., (xn, yn)}, where the x's are the inputs and the y's the outputs. The problem is to predict the y's from the x's. Further assume that p(D|θ) is the output of a neural network with weights θ. The network loss is defined as the negative log-likelihood

    Ln(θ) = −log p(D|θ) = −Σi log p(yi|xi, θ)

Usually, when training a neural network, we try to find the parameter θ* which minimizes Ln(θ).

In Bayesian inference, the problem is instead to study the posterior distribution of the weights given the data. Assume we have a prior α over ℝ^d. The posterior is

    p(θ|D) = p(D|θ) α(θ) / p(D),   where p(D) = ∫ p(D|θ) α(θ) dθ

This can be used for model selection, or prediction with Bayesian Model Averaging.
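
As a toy illustration of the difference between a point estimate θ* and the posterior, the sketch below applies Bayes' rule on a finite grid of θ values for a coin-flip model and then predicts by Bayesian Model Averaging. The coin model, the grid and all function names are assumptions made for this example only; none of this is part of PyVarInf.

```python
import math

def log_likelihood(theta, data):
    # log p(D | theta) for a coin with P(y = 1) = theta
    return sum(math.log(theta if y == 1 else 1.0 - theta) for y in data)

def grid_posterior(data, grid, prior):
    """Posterior weights p(theta | D) on a finite grid (Bayes' rule)."""
    unnorm = [math.exp(log_likelihood(t, data)) * p for t, p in zip(grid, prior)]
    evidence = sum(unnorm)  # p(D), the normalising constant
    return [u / evidence for u in unnorm]

def bma_predict(grid, posterior):
    """Bayesian Model Averaging: P(y = 1 | D) = sum of theta * p(theta | D)."""
    return sum(t * w for t, w in zip(grid, posterior))

data = [1, 1, 0, 1, 1, 1, 0, 1]
grid = [i / 100 for i in range(1, 100)]   # theta in {0.01, ..., 0.99}
prior = [1.0 / len(grid)] * len(grid)     # uniform prior alpha on the grid

posterior = grid_posterior(data, grid, prior)
theta_star = max(grid, key=lambda t: log_likelihood(t, data))  # minimises Ln(theta)
print("point estimate:", theta_star)
print("BMA prediction:", bma_predict(grid, posterior))
```

The point estimate keeps a single θ, whereas the averaged prediction weights every θ on the grid by its posterior probability.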

Variational Inference

It is usually impossible to analytically compute the posterior distribution, especially with models as complex as neural networks. Variational Inference addresses this problem by approximating the posterior p(θ|D) by a parametric distribution q(θ|φ), where φ is a parameter. The problem is then not to learn a parameter θ* but a probability distribution q(θ|φ) minimizing

    F(φ) = KL(q(.|φ)||α) + Eθ ~ q(θ|φ)[Ln(θ)]

F is called the variational free energy.

This idea was originally introduced for deep learning by Hinton and Van Camp [5] as a way to use neural networks for Minimum Description Length [3]. MDL aims at minimizing the number of bits used to encode the whole dataset. Variational inference introduces one of many data encoding schemes. Indeed, F can be interpreted as the total description length of the dataset D, when we first encode the model, then encode the part of the data not explained by the model:

  • LC(φ) = KL(q(.|φ)||α) is the complexity loss. It measures (in nats) the quantity of information contained in the model. It is indeed possible to encode the model in LC(φ) nats, with the bits-back code [4].
  • LE(φ) = Eθ ~ q(θ|φ)[Ln(θ)] is the error loss. It measures the necessary quantity of information for encoding the data D with the model. This code length can be achieved with a Shannon-Huffman code for instance.

Therefore F(φ) = LC(φ) + LE(φ) can be rephrased as an MDL loss function which measures the total encoding length of the data.
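
The decomposition above can be sketched numerically for a diagonal Gaussian q and a zero-mean isotropic Gaussian prior, for which the complexity loss has a closed form. The quadratic toy_loss is a hypothetical stand-in for Ln(θ), and the function names are illustrative; this is not PyVarInf code.

```python
import math
import random

def complexity_loss(mu, sigma, prior_sigma=1.0):
    """LC: KL(q || alpha) in nats, q = N(mu, diag(sigma^2)), alpha = N(0, prior_sigma^2 I)."""
    return sum(
        math.log(prior_sigma / s) + (s * s + m * m) / (2.0 * prior_sigma ** 2) - 0.5
        for m, s in zip(mu, sigma)
    )

def error_loss(mu, sigma, loss_fn, n_samples=5000, seed=0):
    """LE: Monte Carlo estimate of the expectation of loss_fn(theta) under q."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        theta = [rng.gauss(m, s) for m, s in zip(mu, sigma)]
        total += loss_fn(theta)
    return total / n_samples

def toy_loss(theta, target=(1.0, -2.0)):
    # Hypothetical stand-in for Ln(theta): a quadratic bowl around `target`.
    return sum((t - c) ** 2 for t, c in zip(theta, target))

mu, sigma = [0.0, -1.0], [0.5, 0.5]
lc = complexity_loss(mu, sigma)
le = error_loss(mu, sigma, toy_loss)
free_energy = lc + le
print(f"LC = {lc:.3f} nats, LE = {le:.3f}, F = {free_energy:.3f}")
```

Shrinking σ lowers the error loss but raises the complexity loss, which is the trade-off that F balances.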

Practical Variational Optimisation

In practice, we define φ = (µ, σ) in ℝ^d × ℝ^d, and q(.|φ) = N(µ, Σ) the multivariate Gaussian distribution with Σ = diag(σ_1^2, ..., σ_d^2), and we want to find the optimal µ* and σ*.

With this choice of a Gaussian posterior, a Monte Carlo estimate of the gradient of F w.r.t. µ and σ can be obtained with backpropagation. This makes it possible to use any gradient descent method used for non-variational optimisation [2].
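
One common way to obtain such a Monte Carlo gradient is to write θ = µ + σ ε with ε ~ N(0, 1) and differentiate through the sample. The sketch below does this for the error loss of a one-dimensional toy problem, where the derivative of the loss is written by hand instead of being produced by backpropagation. The toy loss and the function names are assumptions, and this is not necessarily the estimator PyVarInf uses internally.

```python
import random

def dloss_dtheta(theta):
    # Derivative of the toy loss Ln(theta) = (theta - 3)^2; for a real
    # network, backpropagation would supply this quantity.
    return 2.0 * (theta - 3.0)

def reparam_gradients(mu, sigma, n_samples=20000, seed=0):
    """Monte Carlo estimate of the gradient of E_q[Ln(theta)] w.r.t. mu and sigma."""
    rng = random.Random(seed)
    g_mu = 0.0
    g_sigma = 0.0
    for _ in range(n_samples):
        eps = rng.gauss(0.0, 1.0)
        theta = mu + sigma * eps   # theta ~ N(mu, sigma^2)
        g = dloss_dtheta(theta)
        g_mu += g                  # chain rule: d theta / d mu = 1
        g_sigma += g * eps         # chain rule: d theta / d sigma = eps
    return g_mu / n_samples, g_sigma / n_samples

g_mu, g_sigma = reparam_gradients(mu=1.0, sigma=0.5)
print("estimated gradients:", g_mu, g_sigma)
print("analytic gradients: ", 2.0 * (1.0 - 3.0), 2.0 * 0.5)
```

For this quadratic loss the exact gradients are 2(µ − 3) and 2σ, so the estimate can be checked directly against them.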

Overview of PyVarInf

The core feature of PyVarInf is the Variationalize function. Variationalize takes a model as input and outputs a variationalized version of the model with a Gaussian posterior.

Definition of a variational model

To define a variational model, first define a traditional PyTorch model, then use the Variationalize function:

import pyvarinf
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)
        self.bn1 = nn.BatchNorm2d(10)
        self.bn2 = nn.BatchNorm2d(20)

    def forward(self, x):
        x = self.bn1(F.relu(F.max_pool2d(self.conv1(x), 2)))
        x = self.bn2(F.relu(F.max_pool2d(self.conv2(x), 2)))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        # Normalise over the class dimension (an explicit dim is required in recent PyTorch)
        return F.log_softmax(x, dim=1)

model = Net()
var_model = pyvarinf.Variationalize(model)
var_model.cuda()

Optimisation of a variational model

The var_model can then be trained as follows:

from torch.autograd import Variable

# train_loader must be defined beforehand (a DataLoader over the training set)
optimizer = optim.Adam(var_model.parameters(), lr=0.01)

def train(epoch):
    var_model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = var_model(data)
        # Error loss: mean negative log-likelihood over the batch
        loss_error = F.nll_loss(output, target)
        # The model is only sent once, thus the division by
        # the number of datapoints used to train
        loss_prior = var_model.prior_loss() / 60000
        loss = loss_error + loss_prior
        loss.backward()
        optimizer.step()

for epoch in range(1, 500):
    train(epoch)

Available priors

In PyVarInf, we have implemented four families of priors:

Gaussian prior

The Gaussian prior is N(0, Σ), with Σ the diagonal matrix diag(σ_1^2, ..., σ_d^2) defined such that 1/σi is the square root of the number of parameters in the layer, following the standard initialisation of neural network weights. It is the default prior and does not have any parameters. It can be set with:

var_model.set_prior('gaussian')
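
To make the scaling rule concrete, the sketch below computes σ for a layer by reading "number of parameters in the layer" literally, and evaluates the negative log-density of a few weights under the resulting N(0, σ^2) prior. The helper names and the parameter count are hypothetical, and PyVarInf's own computation may differ in detail.

```python
import math

def gaussian_prior_std(n_params):
    """Prior standard deviation sigma such that 1 / sigma = sqrt(n_params)."""
    return 1.0 / math.sqrt(n_params)

def gaussian_prior_nll(weights, sigma):
    """Negative log-density (in nats) of the weights under N(0, sigma^2)."""
    return sum(
        0.5 * math.log(2.0 * math.pi) + math.log(sigma) + 0.5 * (w / sigma) ** 2
        for w in weights
    )

sigma = gaussian_prior_std(100)   # a hypothetical layer with 100 parameters
print("sigma:", sigma)
print("nll of small weights:", gaussian_prior_nll([0.01, -0.02, 0.0], sigma))
print("nll of large weights:", gaussian_prior_nll([1.0, -2.0, 0.5], sigma))
```

Larger layers get a narrower prior, so large weights in wide layers are penalised more strongly.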

Conjugate priors

The conjugate prior is used if we assume that all the weights in a given layer should be distributed as a Gaussian, but with unknown mean and variance. See [6] for more details. This prior can be set with:

var_model.set_prior('conjugate', n_mc_samples, alpha_0, beta_0, mu_0, kappa_0)

There are five parameters that have to be set:

  • n_mc_samples, the number of samples used in the Monte Carlo estimation of the prior loss and its gradient.
  • mu_0, the prior sample mean
  • kappa_0, the number of samples used to estimate the prior sample mean
  • alpha_0 and beta_0, such that the variance was estimated from 2 alpha_0 observations with sample mean mu_0 and sum of squared deviations 2 beta_0
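
The role of these hyperparameters can be illustrated by sampling from a normal-inverse-gamma prior, the standard conjugate prior for a Gaussian with unknown mean and variance. The sketch assumes the usual parameterisation (variance ~ InvGamma(alpha_0, beta_0), mean ~ N(mu_0, variance / kappa_0)); the function name is hypothetical and this is not PyVarInf's internal code.

```python
import random

def sample_layer_gaussian(mu_0, kappa_0, alpha_0, beta_0, rng):
    """Draw the (mean, variance) of a layer's weight distribution from the prior."""
    # random.gammavariate takes (shape, scale); a Gamma with rate beta_0 has
    # scale 1 / beta_0, and its reciprocal is InvGamma(alpha_0, beta_0).
    precision = rng.gammavariate(alpha_0, 1.0 / beta_0)
    variance = 1.0 / precision
    mean = rng.gauss(mu_0, (variance / kappa_0) ** 0.5)
    return mean, variance

rng = random.Random(0)
draws = [sample_layer_gaussian(0.0, 2.0, 5.0, 4.0, rng) for _ in range(20000)]
avg_mean = sum(m for m, _ in draws) / len(draws)
avg_variance = sum(v for _, v in draws) / len(draws)
print("average sampled mean:    ", avg_mean)
print("average sampled variance:", avg_variance)
```

With alpha_0 > 1 the expected variance is beta_0 / (alpha_0 − 1), and a larger kappa_0 concentrates the sampled means more tightly around mu_0.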

Conjugate prior with known mean

The conjugate prior with known mean is similar to the conjugate prior. It is used if we assume that all the weights in a given layer should be distributed as a Gaussian with a known mean but unknown variance. It is useful in neural network models when we assume that the weights in a layer should have mean 0. See [6] for more details. This prior can be set with:

var_model.set_prior('conjugate_known_mean', n_mc_samples, mean, alpha_0, beta_0)

Four parameters have to be set:

  • n_mc_samples, the number of samples used in the Monte Carlo estimation of the prior loss and its gradient.
  • mean, the known mean
  • alpha_0 and beta_0 defined as above

Mixture of two Gaussians

The idea of using a mixture of two Gaussians as a prior was introduced in [1]. This prior can be set with:

var_model.set_prior('mixtgauss', n_mc_samples, sigma_1, sigma_2, pi)

Four parameters have to be set:

  • n_mc_samples, the number of samples used in the Monte Carlo estimation of the prior loss and its gradient.
  • sigma_1 and sigma_2, the standard deviations of the two Gaussians
  • pi, the probability of the first Gaussian
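
The sketch below shows how these parameters could enter the prior loss: it evaluates the log-density of a zero-mean mixture of two Gaussians, as in [1], and uses it in a Monte Carlo estimate of KL(q||α) for a one-dimensional Gaussian q. The function names are hypothetical and the code only illustrates the idea; it is not PyVarInf's implementation.

```python
import math
import random

def log_normal_pdf(w, sigma):
    # Log-density of N(0, sigma^2) evaluated at w
    return -0.5 * math.log(2.0 * math.pi) - math.log(sigma) - 0.5 * (w / sigma) ** 2

def log_mixture_prior(w, sigma_1, sigma_2, pi):
    """log(pi * N(w; 0, sigma_1^2) + (1 - pi) * N(w; 0, sigma_2^2)), with 0 < pi < 1."""
    a = math.log(pi) + log_normal_pdf(w, sigma_1)
    b = math.log(1.0 - pi) + log_normal_pdf(w, sigma_2)
    m = max(a, b)  # log-sum-exp for numerical stability
    return m + math.log(math.exp(a - m) + math.exp(b - m))

def mc_prior_loss(mu, sigma, n_mc_samples, sigma_1, sigma_2, pi, seed=0):
    """Monte Carlo estimate of KL(q || prior) for q = N(mu, sigma^2)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_mc_samples):
        w = rng.gauss(mu, sigma)
        total += log_normal_pdf(w - mu, sigma) - log_mixture_prior(w, sigma_1, sigma_2, pi)
    return total / n_mc_samples

print(mc_prior_loss(mu=0.3, sigma=0.2, n_mc_samples=1000, sigma_1=1.0, sigma_2=0.01, pi=0.5))
```

More samples (a larger n_mc_samples) reduce the variance of the estimate at the cost of extra computation.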

Requirements

This module requires Python 3. PyTorch must be installed for PyVarInf to work (as PyTorch is not readily available on PyPI). To install PyTorch, follow the instructions on the official PyTorch website.

References

  • [1] Blundell, Charles, Cornebise, Julien, Kavukcuoglu, Koray, and Wierstra, Daan. Weight Uncertainty in Neural Networks. In International Conference on Machine Learning, pp. 1613–1622, 2015.
  • [2] Graves, Alex. Practical Variational Inference for Neural Networks. In Neural Information Processing Systems, 2011.
  • [3] Grünwald, Peter D. The Minimum Description Length Principle. MIT Press, 2007.
  • [4] Honkela, Antti and Valpola, Harri. Variational Learning and Bits-Back Coding: An Information-Theoretic View to Bayesian Learning. IEEE Transactions on Neural Networks, 15(4), 2004.
  • [5] Hinton, Geoffrey E and Van Camp, Drew. Keeping Neural Networks Simple by Minimizing the Description Length of the Weights. In Proceedings of the sixth annual conference on Computational learning theory. ACM, 1993.
  • [6] Murphy, Kevin P. Conjugate Bayesian Analysis of the Gaussian Distribution. 2007.