Bayes-Newton—A Gaussian process library in JAX, with a unifying view of approximate Bayesian inference as variants of Newton's algorithm.

Overview

Bayes-Newton

Bayes-Newton is a library for approximate inference in Gaussian processes (GPs) in JAX (with objax), built and actively maintained by Will Wilkinson.

Bayes-Newton provides a unifying view of approximate Bayesian inference, and allows for the combination of many models (e.g. GPs, sparse GPs, Markov GPs, sparse Markov GPs) with the inference method of your choice (VI, EP, Laplace, Linearisation). For a full list of the implemented methods, see the Implemented Models section further down this page.
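
To make the "variants of Newton's algorithm" view concrete, here is a minimal NumPy sketch of the simplest case: a Laplace approximation for a GP with a Bernoulli-logit likelihood, where the posterior mode is found by Newton iterations. This is not Bayes-Newton's implementation. The function names, the toy data and the hyperparameters are all assumptions made for illustration.

```python
import numpy as np

def matern52(x1, x2, variance=1.0, lengthscale=1.0):
    """Matern-5/2 covariance between two 1-D input arrays."""
    r = np.abs(x1[:, None] - x2[None, :]) / lengthscale
    s = np.sqrt(5.0) * r
    return variance * (1.0 + s + s**2 / 3.0) * np.exp(-s)

def laplace_newton(K, y, iters=30):
    """Newton iterations for the mode of p(f | y), with prior N(0, K)
    and a Bernoulli-logit likelihood (labels y in {0, 1})."""
    n = len(y)
    f = np.zeros(n)
    for _ in range(iters):
        p = 1.0 / (1.0 + np.exp(-f))   # predicted class probabilities
        grad = y - p                   # gradient of the log-likelihood w.r.t. f
        W = p * (1.0 - p)              # diagonal of the negative log-likelihood Hessian
        # Full Newton step: f <- (K^{-1} + W)^{-1} (W f + grad)
        #                     = K (I + W K)^{-1} (W f + grad)
        b = W * f + grad
        f = K @ np.linalg.solve(np.eye(n) + W[:, None] * K, b)
    return f

# Toy data (illustrative only)
x = np.linspace(0.0, 5.0, 20)
y = (np.sin(x) > 0).astype(float)
K = matern52(x, x)
f_mode = laplace_newton(K, y)
```

At the mode the gradient of the log-posterior vanishes, so `f_mode` should satisfy `f_mode = K @ (y - sigmoid(f_mode))`. Roughly speaking, the other inference methods in the library can be read as changing how the gradient and Hessian terms in this update are computed.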

Installation

pip install bayesnewton

Example

Given some inputs x and some data y, you can construct a Bayes-Newton model as follows:

import bayesnewton

kern = bayesnewton.kernels.Matern52()
lik = bayesnewton.likelihoods.Gaussian()
model = bayesnewton.models.MarkovVariationalGP(kernel=kern, likelihood=lik, X=x, Y=y)

The training loop (inference and hyperparameter learning) is then set up using objax's built-in functionality:

import objax

lr_adam = 0.1
lr_newton = 1
inf_args = {}  # extra keyword arguments for inference/energy; assumed empty here
opt_hypers = objax.optimizer.Adam(model.vars())
energy = objax.GradValues(model.energy, model.vars())

@objax.Function.with_vars(model.vars() + opt_hypers.vars())
def train_op():
    model.inference(lr=lr_newton, **inf_args)  # perform inference and update variational params
    dE, E = energy(**inf_args)  # compute energy and its gradients w.r.t. hypers
    opt_hypers(lr_adam, dE)  # update the hyperparameters
    return E

As we are using JAX, we can JIT compile the training loop:

train_op = objax.Jit(train_op)

and then run the training loop:

iters = 20
for i in range(1, iters + 1):
    loss = train_op()

Full demos are available here.

License

This software is provided under the Apache License 2.0. See the accompanying LICENSE file for details.

Citing Bayes-Newton

@article{wilkinson2021bayesnewton,
  title = {{B}ayes-{N}ewton Methods for Approximate {B}ayesian Inference with {PSD} Guarantees},
  author = {Wilkinson, William J. and S\"arkk\"a, Simo and Solin, Arno},
  journal={arXiv preprint arXiv:2111.01721},
  year={2021}
}

Implemented Models

For a full list of all the models available, see the model class list.

Variational GPs

  • Variational GP (Opper, Archambeau: The Variational Gaussian Approximation Revisited, Neural Computation 2009; Khan, Lin: Conjugate-Computation Variational Inference - Converting Inference in Non-Conjugate Models into Inference in Conjugate Models, AISTATS 2017)
  • Sparse Variational GP (Hensman, Matthews, Ghahramani: Scalable Variational Gaussian Process Classification, AISTATS 2015; Adam, Chang, Khan, Solin: Dual Parameterization of Sparse Variational Gaussian Processes, NeurIPS 2021)
  • Markov Variational GP (Chang, Wilkinson, Khan, Solin: Fast Variational Learning in State Space Gaussian Process Models, MLSP 2020)
  • Sparse Markov Variational GP (Adam, Eleftheriadis, Durrande, Artemev, Hensman: Doubly Sparse Variational Gaussian Processes, AISTATS 2020; Wilkinson, Solin, Adam: Sparse Algorithms for Markovian Gaussian Processes, AISTATS 2021)
  • Spatio-Temporal Variational GP (Hamelijnck, Wilkinson, Loppi, Solin, Damoulas: Spatio-Temporal Variational Gaussian Processes, NeurIPS 2021)

Expectation Propagation GPs

  • Expectation Propagation GP (Minka: A Family of Algorithms for Approximate Bayesian Inference, PhD thesis 2000)
  • Sparse Expectation Propagation GP (energy not working) (Csato, Opper: Sparse on-line Gaussian processes, Neural Computation 2002; Bui, Yan, Turner: A Unifying Framework for Gaussian Process Pseudo Point Approximations Using Power Expectation Propagation, JMLR 2017)
  • Markov Expectation Propagation GP (Wilkinson, Chang, Riis Andersen, Solin: State Space Expectation Propagation, ICML 2020)
  • Sparse Markov Expectation Propagation GP (Wilkinson, Solin, Adam: Sparse Algorithms for Markovian Gaussian Processes, AISTATS 2021)

Laplace/Newton GPs

  • Laplace GP (Rasmussen, Williams: Gaussian Processes for Machine Learning, 2006)
  • Sparse Laplace GP (Wilkinson, Särkkä, Solin: Bayes-Newton Methods for Approximate Bayesian Inference with PSD Guarantees)
  • Markov Laplace GP (Wilkinson, Särkkä, Solin: Bayes-Newton Methods for Approximate Bayesian Inference with PSD Guarantees)
  • Sparse Markov Laplace GP

Linearisation GPs

  • Posterior Linearisation GP (García-Fernández, Tronarp, Särkkä: Gaussian Process Classification Using Posterior Linearization, IEEE Signal Processing 2019; Steinberg, Bonilla: Extended and Unscented Gaussian Processes, NeurIPS 2014)
  • Sparse Posterior Linearisation GP
  • Markov Posterior Linearisation GP (García-Fernández, Svensson, Särkkä: Iterated Posterior Linearization Smoother, IEEE Automatic Control 2016; Wilkinson, Chang, Riis Andersen, Solin: State Space Expectation Propagation, ICML 2020)
  • Sparse Markov Posterior Linearisation GP (Wilkinson, Solin, Adam: Sparse Algorithms for Markovian Gaussian Processes, AISTATS 2021)
  • Taylor Expansion / Analytical Linearisation GP (Steinberg, Bonilla: Extended and Unscented Gaussian Processes, NeurIPS 2014)
  • Markov Taylor GP / Extended Kalman Smoother (Bell: The Iterated Kalman Smoother as a Gauss-Newton method, SIAM Journal on Optimization 1994)
  • Sparse Taylor GP
  • Sparse Markov Taylor GP / Sparse Extended Kalman Smoother (Wilkinson, Solin, Adam: Sparse Algorithms for Markovian Gaussian Processes, AISTATS 2021)

Gauss-Newton GPs

(Wilkinson, Särkkä, Solin: Bayes-Newton Methods for Approximate Bayesian Inference with PSD Guarantees)

  • Gauss-Newton
  • Variational Gauss-Newton
  • PEP Gauss-Newton
  • 2nd-order PL Gauss-Newton

Quasi-Newton GPs

(Wilkinson, Särkkä, Solin: Bayes-Newton Methods for Approximate Bayesian Inference with PSD Guarantees)

  • Quasi-Newton
  • Variational Quasi-Newton
  • PEP Quasi-Newton
  • PL Quasi-Newton

GPs with PSD Constraints via Riemannian Gradients

  • VI Riemann Grad (Lin, Schmidt, Khan: Handling the Positive-Definite Constraint in the Bayesian Learning Rule, ICML 2020)
  • Newton/Laplace Riemann Grad (Lin, Schmidt, Khan: Handling the Positive-Definite Constraint in the Bayesian Learning Rule, ICML 2020)
  • PEP Riemann Grad (Wilkinson, Särkkä, Solin: Bayes-Newton Methods for Approximate Bayesian Inference with PSD Guarantees)

Others

  • Infinite Horizon GP (Solin, Hensman, Turner: Infinite-Horizon Gaussian Processes, NeurIPS 2018)
  • Parallel Markov GP (with VI, EP, PL, ...) (Särkkä, García-Fernández: Temporal parallelization of Bayesian smoothers; Corenflos, Zhao, Särkkä: Gaussian Process Regression in Logarithmic Time; Hamelijnck, Wilkinson, Loppi, Solin, Damoulas: Spatio-Temporal Variational Gaussian Processes, NeurIPS 2021)
  • 2nd-order Posterior Linearisation GP (sparse, Markov, ...) (Wilkinson, Särkkä, Solin: Bayes-Newton Methods for Approximate Bayesian Inference with PSD Guarantees)
Comments
  • Addition of a Squared Exponential Kernel?

    This is not an issue per se, but I was wondering if there was a specific reason that there wasn't a Squared Exponential kernel as part of the package?

    If applicable, I would be happy to submit a PR adding one.

    Just let me know.

    opened by mathDR 7
  • Add sq exponential kernel

    This PR does two things:

    1. Adds a Squared Exponential kernel. It adopts the gpflow naming convention, i.e. SquaredExponential inheriting from StationaryKernel. As of now, only the K_r method is populated (others may be populated as needed).
    2. The marathon.py demo was changed to take the SquaredExponential kernel (with lengthscale = 40). This serves as a "test" to ensure that the kernel runs.
    opened by mathDR 2
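
As a rough illustration of what a K_r method for the kernel in the PR above might compute, the sketch below evaluates the squared exponential covariance as a function of the scaled distance r = |x - x'| / lengthscale, i.e. variance * exp(-r^2 / 2). It is a standalone NumPy sketch: the function names are hypothetical, and the assumption that K_r receives an already-scaled distance may not match the library's class layout.

```python
import numpy as np

def scaled_distance(x1, x2, lengthscale):
    """Pairwise |x - x'| / lengthscale for 1-D inputs."""
    return np.abs(x1[:, None] - x2[None, :]) / lengthscale

def squared_exponential_K_r(r, variance=1.0):
    """Squared exponential covariance as a function of scaled distance r."""
    return variance * np.exp(-0.5 * r**2)

# Lengthscale 40 mirrors the value mentioned for the marathon demo;
# the inputs and variance are arbitrary.
x = np.array([0.0, 40.0, 80.0])
K = squared_exponential_K_r(scaled_distance(x, x, lengthscale=40.0), variance=2.0)
```

Inputs that are exactly one lengthscale apart have covariance variance * exp(-0.5), and the diagonal equals the variance.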
  • question about the equation (64) in `Bayes-Newton` paper

    I'm a little confused by equation (64). It is derived from equation (63), but where are the other terms in equation (64), such as the denominator that comes from the covariance of p(fn|u)?

    opened by Fangwq 2
  • jitted predict

    Hi, I'm starting to explore your framework. I'm familiar with jax, but not with objax. I noticed that train ops are jitted with objax.Jit, but as my goal is to have fast prediction embedded in some larger jax code, I wonder if predict() can also be jitted? Thanks in advance,

    Regards,

    opened by soldierofhell 2
  • error in heteroscedastic.py with MarkovPosteriorLinearisationQuasiNewtonGP method

    As the title says, there is an error after running heteroscedastic.py with the MarkovPosteriorLinearisationQuasiNewtonGP method:

      File "heteroscedastic.py", line 101, in train_op
        model.inference(lr=lr_newton, damping=damping)  # perform inference and update variational params
      File "/BayesNewton-main/bayesnewton/inference.py", line 871, in inference
        mean, jacobian, hessian, quasi_newton_state = self.update_variational_params(batch_ind, lr, **kwargs)
      File "/BayesNewton-main/bayesnewton/inference.py", line 1076, in update_variational_params
        jacobian_var = transpose(solve(omega, dmu_dv)) @ residual
    ValueError: The arguments to solve must have shapes a=[..., m, m] and b=[..., m, k] or b=[..., m]; got a=(117, 1, 1) and b=(117, 2, 2)

    Can you tell me what is going wrong? Thanks in advance.

    opened by Fangwq 2
  • How to set initial `Pinf` variable in kernel?

    I note that the initial Pinf variable for the Matern-5/2 kernel is as follows:

            Pinf = np.array([[self.variance,    0.0,   -kappa],
                             [0.0,    kappa, 0.0],
                             [-kappa, 0.0,   25.0*self.variance / self.lengthscale**4.0]])
    

    Why is it like that? Are there any references I should follow up on?

    PS: by the way, the data ../data/aq_data.csv is missing.

    opened by Fangwq 2
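
On the Pinf question above: for a stationary state-space model dx = F x dt + L dβ driven by white noise with spectral density q, the stationary covariance Pinf satisfies the Lyapunov equation F Pinf + Pinf Fᵀ + q L Lᵀ = 0. The NumPy sketch below checks numerically that the quoted matrix satisfies this equation for the standard companion-form Matern-5/2 model. Two things are assumptions rather than taken from the snippet: that kappa = 5/3 · variance / lengthscale², and that F, L and q take the usual textbook form shown here. The hyperparameter values are arbitrary.

```python
import numpy as np

variance, lengthscale = 1.5, 0.7          # arbitrary illustrative values
lam = np.sqrt(5.0) / lengthscale

# Companion-form drift matrix, noise effect vector and spectral density
# of the Matern-5/2 SDE (assumed textbook form).
F = np.array([[0.0, 1.0, 0.0],
              [0.0, 0.0, 1.0],
              [-lam**3, -3.0 * lam**2, -3.0 * lam]])
L = np.array([[0.0], [0.0], [1.0]])
q = 16.0 / 3.0 * variance * lam**5

# Stationary covariance as quoted in the issue, with kappa assumed as below.
kappa = 5.0 / 3.0 * variance / lengthscale**2
Pinf = np.array([[variance, 0.0, -kappa],
                 [0.0, kappa, 0.0],
                 [-kappa, 0.0, 25.0 * variance / lengthscale**4]])

# Lyapunov residual: should vanish up to floating point error.
residual = F @ Pinf + Pinf @ F.T + q * (L @ L.T)
print(np.max(np.abs(residual)))
```

The top-left entry of Pinf is the marginal variance of the process itself; the remaining entries are the stationary (co)variances of its first and second derivatives.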
  • How to install Newt in conda virtual environment?

    Hi, thank you for sharing your great work.

    I am a little confused about how to install Newt in a conda virtual environment. I would really appreciate it if you could guide me in this regard. Thank you

    opened by mohammad-saber 2
  • How to understand the function `cavity_distribution_tied` in file `basemodels.py` ?

    Just as the title says, how can I understand cavity_distribution_tied in file basemodels.py? Is there any reference I should follow up on? I also note that this code is similar to equation (64) in the BayesNewton paper. Where does it come from?

    opened by Fangwq 1
  • issue with SparseVariationalGP method

    When I run the code file demos/regression.py with SparseVariationalGP, something goes wrong:

    AssertionError: Assignments to variable must be an instance of JaxArray, but received f<class 'numpy.ndarray'>.

    There seems to be a mistake in the SparseVariationalGP method. Can you help to fix the problem? Thank you very much!

    opened by Fangwq 1
  • Sparse EP energy is incorrect

    The current implementation of the sparse EP energy is not giving sensible results. This is a reminder to look into the reasons why and check against implementations elsewhere. PRs very welcome for this issue.

    Note: the EP energy is correct for all other models (GP, Markov GP, SparseMarkovGP)

    bug 
    opened by wil-j-wil 1
  • Double Precision Issues

    Hi!

    Many thanks for open-sourcing this package.

    I've been using code from this amazing package for my research (preprint here) and have found that the default single precision/float32 is insufficient for the Kalman filtering and smoothing operations, causing numerical instabilities. In particular,

    • for the Periodic kernel, it is rather sensitive to the matrix operations in _sequential_kf() and _sequential_rts().
    • Likewise, the same when the lengthscales are too large in the Matern32 kernel.

    However, reverting to float64 by setting config.update("jax_enable_x64", True) makes everything quite slow, especially when I use objax neural network modules, due to the fact that doing so puts all arrays into double precision.

    Currently, my solution is to set the neural network weights to float32 manually, and convert input arrays into float32 before entering the network and the outputs back into float64. However, I was wondering if there could be a more elegant solution as is done in https://github.com/thomaspinder/GPJax, where all arrays are assumed to be float64. My understanding is that their package depends on Haiku, but I'm unsure how they got around the computational scalability issue.

    Software and hardware details:

    objax==1.6.0
    jax==0.3.13
    jaxlib==0.3.10+cuda11.cudnn805
    
    NVIDIA-SMI 460.56       Driver Version: 460.56       CUDA Version: 11.2
    GeForce RTX 3090 GPUs
    

    Thanks in advance.

    Best, Harrison

    opened by harrisonzhu508 1
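
The workaround described in the issue above (float64 for the filtering and smoothing code, float32 inside the neural network) comes down to casting at the network boundary. The sketch below shows only that casting pattern, using NumPy as a stand-in for objax/JAX. `tiny_network`, its weights and `network_in_float64_pipeline` are hypothetical names, not part of any library.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical network weights, stored in single precision.
W1 = rng.standard_normal((3, 8)).astype(np.float32)
W2 = rng.standard_normal((8, 1)).astype(np.float32)

def tiny_network(x32):
    """A two-layer network evaluated entirely in float32."""
    return np.tanh(x32 @ W1) @ W2

def network_in_float64_pipeline(x64):
    """Cast down before the network and back up afterwards, so the
    surrounding (e.g. Kalman filtering) code keeps working in float64."""
    x32 = x64.astype(np.float32)
    out32 = tiny_network(x32)
    return out32.astype(np.float64)

x = rng.standard_normal((5, 3))           # float64 by default
out = network_in_float64_pipeline(x)
```

In JAX the same idea would use `astype` on arrays after enabling `jax_enable_x64`. Whether this recovers the single-precision speed depends on the model and hardware.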
  • Cannot run demo, possible incompatibility with latest Jax

    Dear all,

    I am trying to run the demo examples, but I run into the following error:


    ImportError                               Traceback (most recent call last)
    Input In [22], in <cell line: 1>()
    ----> 1 import bayesnewton
          2 import objax
          3 import numpy as np

    File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/bayesnewton/__init__.py:1, in <module>
    ----> 1 from . import (
          2     kernels,
          3     utils,
          4     ops,
          5     likelihoods,
          6     models,
          7     basemodels,
          8     inference,
          9     cubature
         10 )
         13 def build_model(model, inf, name='GPModel'):
         14     return type(name, (inf, model), {})

    File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/bayesnewton/kernels.py:5, in <module>
          3 import jax.numpy as np
          4 from jax.scipy.linalg import cho_factor, cho_solve, block_diag, expm
    ----> 5 from jax.ops import index_add, index
          6 from .utils import scaled_squared_euclid_dist, softplus, softplus_inv, rotation_matrix
          7 from warnings import warn

    ImportError: cannot import name 'index_add' from 'jax.ops' (/Users/Daniel/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/jax/ops/__init__.py)

    I think it's related to this note from the JAX website:

    The functions jax.ops.index_update, jax.ops.index_add, etc., which were deprecated in JAX 0.2.22, have been removed. Please use the jax.numpy.ndarray.at property on JAX arrays instead.

    opened by daniel-trejobanos 3
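
The replacement mentioned in the JAX note quoted above looks roughly like this. `x.at[idx].add(v)` and `x.at[idx].set(v)` are the functional counterparts of the removed `jax.ops.index_add` and `jax.ops.index_update`. The array and index values are illustrative and are not taken from kernels.py.

```python
import jax.numpy as jnp

x = jnp.zeros(5)

# Old (removed): x_added = jax.ops.index_add(x, jax.ops.index[1:3], 2.0)
x_added = x.at[1:3].add(2.0)

# Old (removed): x_set = jax.ops.index_update(x_added, jax.ops.index[0], 7.0)
x_set = x_added.at[0].set(7.0)
```

Both calls return new arrays and leave `x` unchanged, since JAX arrays are immutable.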
  • Latest versions of JAX and objax cause compile slow down

    It is recommended to use the following versions of jax and objax:

    jax==0.2.9
    jaxlib==0.1.60
    objax==1.3.1
    

    This is because of this objax issue which causes the model to JIT compile "twice", i.e. on the first two iterations rather than just the first. This causes a bit of a slowdown for large models, but is not a problem otherwise.

    bug 
    opened by wil-j-wil 0
Releases(v1.2.0)
Owner
AaltoML
Machine learning group at Aalto University led by Prof. Solin