Code for "Infinitely Deep Bayesian Neural Networks with Stochastic Differential Equations"

Overview

Infinitely Deep Bayesian Neural Networks with SDEs

This library contains JAX and PyTorch implementations of neural ODEs and Bayesian layers for stochastic variational inference. A rudimentary JAX implementation of differentiable SDE solvers is also provided. For a full set of differentiable SDE solvers in PyTorch, refer to torchsde [2]; for differentiable ODE solvers, refer to torchdiffeq [3].

Continuous-depth hidden unit trajectories in a Neural ODE vs. uncertain posterior dynamics in an SDE-BNN.

Installation

This library runs on jax==0.1.77 and torch==1.6.0. To install all other requirements:

pip install -r requirements.txt

Note: Package versions may change. Refer to the official JAX installation instructions.

JaxSDE: Differentiable SDE Solvers in JAX

The jaxsde library contains SDE solvers in the Ito and Stratonovich forms. Select a solver with method={euler_maruyama|milstein|euler_heun}; these have strong orders 0.5|1|0.5 in general and 1|1|1 for an SDE with additive noise. The stochastic adjoint training mode (sdeint_ito) does not work efficiently yet; use sdeint_ito_fixed_grid for now. Trade off solver speed for precision during training or inference by adjusting --nsteps <# steps>.
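As a rough illustration of what the method option selects between, here is a minimal plain-Python sketch of fixed-grid Euler-Maruyama and Milstein integration for a scalar Ito SDE. It is not the jaxsde implementation: the function name sde_fixed_grid, its arguments, the seed handling, and the explicit dg_dy derivative are all assumptions made for this example.

```python
import math
import random


def sde_fixed_grid(f, g, dg_dy, y0, t0, t1, nsteps, seed, method="euler_maruyama"):
    """Integrate the scalar Ito SDE dy = f(t, y) dt + g(t, y) dW on a uniform grid.

    Hypothetical helper for illustration only; not part of jaxsde.
    """
    if method not in ("euler_maruyama", "milstein"):
        raise ValueError(f"unknown method: {method}")
    rng = random.Random(seed)  # same seed -> same Brownian path
    dt = (t1 - t0) / nsteps
    t, y = t0, y0
    for _ in range(nsteps):
        dw = rng.gauss(0.0, math.sqrt(dt))  # Brownian increment, N(0, dt)
        a, b = f(t, y), g(t, y)
        y_next = y + a * dt + b * dw  # Euler-Maruyama update
        if method == "milstein":
            # Ito-Milstein correction; zero when g does not depend on y.
            y_next += 0.5 * b * dg_dy(t, y) * (dw * dw - dt)
        t, y = t + dt, y_next
    return y


# Ornstein-Uhlenbeck process with additive noise: dy = -y dt + 0.2 dW.
def drift(t, y):
    return -y


def diffusion(t, y):
    return 0.2


def diffusion_dy(t, y):
    return 0.0


y_em = sde_fixed_grid(drift, diffusion, diffusion_dy, 1.0, 0.0, 1.0,
                      nsteps=20, seed=0)
y_mil = sde_fixed_grid(drift, diffusion, diffusion_dy, 1.0, 0.0, 1.0,
                       nsteps=20, seed=0, method="milstein")
```

With additive noise the Milstein correction term vanishes, so y_em and y_mil coincide on the same Brownian path. Increasing nsteps shrinks the step size dt, which costs more solver steps in exchange for a smaller discretization error.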

Usage

Default solver: Backpropagation through the solver.

from jaxsde.jaxsde.sdeint import sdeint_ito_fixed_grid

y1 = sdeint_ito_fixed_grid(f, g, y0, ts, rng, fw_params, method="euler_maruyama")

Stochastic adjoint: Uses O(1) memory by solving an adjoint SDE in the backward pass instead of backpropagating through the solver's operations.

from jaxsde.jaxsde.sdeint import sdeint_ito

y1 = sdeint_ito(f, g, y0, ts, rng, fw_params, method="milstein")

Brax: Bayesian SDE Framework in JAX

Implementation of composable Bayesian layers in the stax API. Our SDE Bayesian layers can be used with the SDEBNN block, composed with multiple parameterizations of time-dependent layers in diffeq_layers. Enable the sticking-the-landing (STL) trick during training with --stl to improve the convergence rate. Augment the inputs by a custom amount with --aug <integer>, and set the number of samples averaged over with --nsamples <integer>. If memory constraints pose a problem, train with gradient accumulation (--acc_grad) and gradient checkpointing (--remat).
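For readers unfamiliar with gradient accumulation, the following plain-Python sketch shows the general idea on a one-parameter least-squares model: gradients from small micro-batches are accumulated and averaged so that the result matches the full-batch gradient while only one micro-batch is processed at a time. This illustrates the technique in general and is not how --acc_grad is implemented in this repository; the function names and the toy data are assumptions made for the example.

```python
def grad_mse(w, batch):
    """Gradient w.r.t. w of the mean squared error of y ~ w * x over `batch`."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_grad(w, data, micro_batch_size):
    """Accumulate micro-batch gradients, weighting each by its size."""
    total = 0.0
    for start in range(0, len(data), micro_batch_size):
        micro = data[start:start + micro_batch_size]
        # Multiply by len(micro) so a smaller final micro-batch is not over-weighted.
        total += grad_mse(w, micro) * len(micro)
    return total / len(data)


# Toy (x, y) pairs, made up for this example.
data = [(1.0, 2.0), (2.0, 3.9), (3.0, 6.1), (4.0, 8.2), (5.0, 9.8)]

full = grad_mse(0.5, data)                             # one pass over all data
acc = accumulated_grad(0.5, data, micro_batch_size=2)  # three micro-batches
```

Up to floating-point rounding, full and acc are equal, so accumulation trades extra passes for lower peak memory without changing the gradient.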

Samples from SDEBNN-learned predictive prior and posterior density distributions.

Usage

All examples can be swapped in with different vision datasets. For better readability, tensorboard logging has been excluded (see torchbnn instead).

Toy 1D regression to learn complex posteriors:

python examples/jax/sdebnn_toy1d.py --ds cos --activn swish --loss laplace --kl_scale 1. --diff_const 0.2 --driftw_scale 0.1 --aug_dim 2 --stl --prior_dw ou

Image Classification:

To train an SDEBNN model:

python examples/jax/sdebnn_classification.py --output <output directory> --model sdenet --aug 2 --nblocks 2-2-2 --diff_coef 0.2 --fx_dim 64 --fw_dims 2-64-2 --nsteps 20 --nsamples 1

To train a ResNet baseline, specify --model resnet; for a Bayesian ResNet baseline, specify --meanfield_sdebnn.

TorchBNN: SDE-BNN in Pytorch

A PyTorch implementation of the Brax framework powered by the torchsde backend.

Usage

All examples can be swapped in with different vision datasets and include tensorboard logging for critical metrics.

Toy 1D regression to learn multi-modal posterior:

python examples/torch/sdebnn_toy1d.py --output_dir <dst_path>

Arbitrarily expressive approximate posteriors from learning non-Gaussian marginals.

Image Classification:

All hyperparameters can be found in the training script. Train with adjoint for memory-efficient backpropagation and with adaptive mode for adaptive computation (ensure --adjoint_adaptive True if training with both adjoint and adaptive modes).

python examples/torch/sdebnn_classification.py --train-dir <output directory> --data cifar10 --dt 0.05 --method midpoint --adjoint True --adaptive True --adjoint_adaptive True --inhomogeneous True

References

[1] Winnie Xu, Ricky T. Q. Chen, Xuechen Li, David Duvenaud. "Infinitely Deep Bayesian Neural Networks with Stochastic Differential Equations." Preprint 2021. [arxiv]

[2] Xuechen Li, Ting-Kam Leonard Wong, Ricky T. Q. Chen, David Duvenaud. "Scalable Gradients for Stochastic Differential Equations." AISTATS 2020. [arxiv]

[3] Ricky T. Q. Chen, Yulia Rubanova, Jesse Bettencourt, David Duvenaud. "Neural Ordinary Differential Equations." NeurIPS 2018. [arxiv]


If you found this library useful in your research, please consider citing:

@article{xu2021sdebnn,
  title={Infinitely Deep Bayesian Neural Networks with Stochastic Differential Equations},
  author={Xu, Winnie and Chen, Ricky T. Q. and Li, Xuechen and Duvenaud, David},
  archivePrefix = {arXiv},
  year={2021}
}