Code for "Infinitely Deep Bayesian Neural Networks with Stochastic Differential Equations"

Overview

Infinitely Deep Bayesian Neural Networks with SDEs

This library contains JAX and PyTorch implementations of neural ODEs and Bayesian layers for stochastic variational inference. A rudimentary JAX implementation of differentiable SDE solvers is also provided; refer to torchsde [2] for a full set of differentiable SDE solvers in PyTorch, and to torchdiffeq [3] for differentiable ODE solvers.

Continuous-depth hidden unit trajectories in a Neural ODE vs. uncertain posterior dynamics in an SDE-BNN.

Installation

This library runs on jax==0.1.77 and torch==1.6.0. To install all other requirements:

pip install -r requirements.txt

Note: package versions may change; refer to the official JAX installation instructions.

JaxSDE: Differentiable SDE Solvers in JAX

The jaxsde library contains SDE solvers for both the Ito and Stratonovich forms. Solvers of different orders can be selected with method={euler_maruyama|milstein|euler_heun}, which have strong orders 0.5|1|0.5 in general and 1|1|1 for SDEs with additive noise. The stochastic adjoint training mode (sdeint_ito) does not work efficiently yet, so use sdeint_ito_fixed_grid for now. Trade off solver speed against precision during training or inference by adjusting --nsteps <# steps>.
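To make the solver description concrete, here is a minimal fixed-grid Euler-Maruyama integrator in JAX. It is a sketch, not jaxsde's implementation: the function name `euler_maruyama` and its `(f, g, y0, t0, t1, nsteps, key)` signature are hypothetical, and it assumes diagonal noise, so `g(t, y)` returns an array shaped like `y`. The `nsteps` argument plays the role of `--nsteps`: more steps cost more compute but reduce discretization error.

```python
import jax
import jax.numpy as jnp


def euler_maruyama(f, g, y0, t0, t1, nsteps, key):
    """Fixed-grid Euler-Maruyama for the Ito SDE dy = f(t, y) dt + g(t, y) dW.

    Hypothetical helper (not the jaxsde API). Assumes diagonal noise: g(t, y)
    returns an array with the same shape as the JAX array y0.
    """
    dt = (t1 - t0) / nsteps
    ts = t0 + dt * jnp.arange(nsteps)  # left endpoint of each step
    # One Brownian increment per step and per state dimension, variance dt.
    dws = jnp.sqrt(dt) * jax.random.normal(key, (nsteps,) + jnp.shape(y0))

    def step(y, inputs):
        t, dw = inputs
        y_next = y + f(t, y) * dt + g(t, y) * dw
        return y_next, None

    y1, _ = jax.lax.scan(step, y0, (ts, dws))
    return y1


# Usage: exponential decay with additive noise on a 3-dimensional state.
key = jax.random.PRNGKey(0)
y1 = euler_maruyama(lambda t, y: -y, lambda t, y: 0.1 * jnp.ones_like(y),
                    jnp.ones(3), 0.0, 1.0, 100, key)
```

With additive noise (a `g` that does not depend on `y`, as in the usage line), Euler-Maruyama attains strong order 1; with state-dependent noise it drops to strong order 0.5.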

Usage

Default solver: Backpropagation through the solver.

from jaxsde.jaxsde.sdeint import sdeint_ito_fixed_grid

y1 = sdeint_ito_fixed_grid(f, g, y0, ts, rng, fw_params, method="euler_maruyama")
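Backpropagating through the solver means differentiating every solver step, which JAX does automatically when the loop is written with `jax.lax.scan`. The sketch below assumes a scalar linear SDE dy = theta * y dt + sigma dW; `solve` and `loss` are hypothetical helpers, not part of jaxsde. Holding the PRNG key fixed makes the Brownian increments constants, so the result is a pathwise (reparameterized) gradient. Memory grows with the number of steps because intermediate states are stored for the backward pass, which is the cost the stochastic adjoint avoids.

```python
import jax
import jax.numpy as jnp


def solve(theta, y0, key, nsteps=100, t1=1.0, sigma=0.2):
    """Euler-Maruyama for dy = theta * y dt + sigma dW (hypothetical helper)."""
    dt = t1 / nsteps
    dws = jnp.sqrt(dt) * jax.random.normal(key, (nsteps,) + y0.shape)

    def step(y, dw):
        return y + theta * y * dt + sigma * dw, None

    y1, _ = jax.lax.scan(step, y0, dws)
    return y1


def loss(theta, y0, key, sigma=0.2):
    # Mean squared terminal state over a batch of independent sample paths.
    return jnp.mean(solve(theta, y0, key, sigma=sigma) ** 2)


# Reverse-mode AD through every scan step. The key is fixed, so the Brownian
# increments are constants and this is a pathwise (reparameterized) gradient.
grad_loss = jax.grad(loss)

key = jax.random.PRNGKey(0)
y0 = jnp.ones(4)
g = grad_loss(-1.0, y0, key)
```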

Stochastic adjoint: uses O(1) memory by solving an adjoint SDE in the backward pass instead of backpropagating through the solver's internal operations.

from jaxsde.jaxsde.sdeint import sdeint_ito

y1 = sdeint_ito(f, g, y0, ts, rng, fw_params, method="milstein")
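The `milstein` method adds a correction term to Euler-Maruyama. As an illustrative sketch (not jaxsde's code), the hypothetical function below implements the Ito Milstein step y <- y + f dt + g dW + 0.5 * g * (dg/dy) * (dW^2 - dt), assuming g acts elementwise so that `jax.jvp` along the direction g yields g * dg/dy. It also returns the Brownian increments so the result can be compared with the exact solution of geometric Brownian motion driven by the same path.

```python
import jax
import jax.numpy as jnp


def milstein(f, g, y0, t0, t1, nsteps, key):
    """Fixed-grid Ito Milstein scheme for dy = f(t, y) dt + g(t, y) dW.

    Hypothetical helper (not the jaxsde API). Assumes g acts elementwise
    (g_i depends only on y_i), so the JVP of g along g equals g * dg/dy.
    Returns the terminal state and the Brownian increments that were used.
    """
    dt = (t1 - t0) / nsteps
    ts = t0 + dt * jnp.arange(nsteps)
    dws = jnp.sqrt(dt) * jax.random.normal(key, (nsteps,) + jnp.shape(y0))

    def step(y, inputs):
        t, dw = inputs
        gy = g(t, y)
        # Forward-mode derivative of g in direction g(t, y): g * dg/dy.
        _, g_dg = jax.jvp(lambda z: g(t, z), (y,), (gy,))
        y_next = y + f(t, y) * dt + gy * dw + 0.5 * g_dg * (dw ** 2 - dt)
        return y_next, None

    y1, _ = jax.lax.scan(step, y0, (ts, dws))
    return y1, dws


# Usage: geometric Brownian motion dy = mu*y dt + sigma*y dW on 8 paths.
mu, sigma = 0.5, 0.5
y1, dws = milstein(lambda t, y: mu * y, lambda t, y: sigma * y,
                   jnp.ones(8), 0.0, 1.0, 1000, jax.random.PRNGKey(0))
# Exact Ito solution at t = 1 driven by the same Brownian path.
exact = jnp.exp((mu - 0.5 * sigma ** 2) + sigma * jnp.sum(dws, axis=0))
```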

Brax: Bayesian SDE Framework in JAX

Implementation of composable Bayesian layers in the stax API. Our SDE Bayesian layers can be used with the SDEBNN block, composed with multiple parameterizations of time-dependent layers in diffeq_layers. The sticking-the-landing (STL) trick can be enabled during training with --stl to improve the convergence rate. Augment the inputs by a custom amount with --aug <integer>, and set the number of samples averaged over with --nsamples <integer>. If memory constraints pose a problem, train in gradient accumulation mode (--acc_grad) and with gradient checkpointing (--remat).
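The sticking-the-landing (STL) estimator drops the score-function term of the ELBO gradient by evaluating log q with the variational parameters under `stop_gradient`, while the sample itself stays reparameterized. The sketch below shows the idea for a diagonal Gaussian q over a vector z; it is not the path-space version used inside the SDEBNN, and `elbo_estimate`, `gaussian_logpdf`, and their arguments are hypothetical. The `nsamples` argument mirrors `--nsamples` by averaging over several draws.

```python
import jax
import jax.numpy as jnp


def gaussian_logpdf(z, mean, log_std):
    # Diagonal Gaussian log-density, summed over dimensions.
    return jnp.sum(-0.5 * ((z - mean) / jnp.exp(log_std)) ** 2
                   - log_std - 0.5 * jnp.log(2 * jnp.pi))


def elbo_estimate(params, log_target, key, nsamples=1, stl=True):
    """Monte Carlo ELBO for q(z) = N(mean, exp(log_std)^2) (hypothetical helper).

    With stl=True, log q is evaluated with stopped-gradient parameters, so
    gradients flow only through the reparameterized samples z.
    """
    mean, log_std = params
    eps = jax.random.normal(key, (nsamples,) + mean.shape)
    z = mean + jnp.exp(log_std) * eps  # reparameterized samples
    q_params = tuple(jax.lax.stop_gradient(p) for p in params) if stl else params
    log_q = jax.vmap(lambda zi: gaussian_logpdf(zi, *q_params))(z)
    log_p = jax.vmap(log_target)(z)
    return jnp.mean(log_p - log_q)  # average over the nsamples draws


mean0 = jnp.array([0.3, -1.0])
log_std0 = jnp.array([0.1, -0.5])
params = (mean0, log_std0)
target = lambda z: gaussian_logpdf(z, mean0, log_std0)  # target equal to q
key = jax.random.PRNGKey(1)
g_stl = jax.grad(elbo_estimate)(params, target, key, 4, True)
g_plain = jax.grad(elbo_estimate)(params, target, key, 4, False)
```

When q equals the target density, the STL gradient is zero for every sample, whereas the standard estimator still has nonzero variance; this is the property that motivates the trick.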

Samples from SDEBNN-learned predictive prior and posterior density distributions.

Usage

All examples can be run with different vision datasets. For readability, TensorBoard logging has been omitted (see TorchBNN below for examples that include it).

Toy 1D regression to learn complex posteriors:

python examples/jax/sdebnn_toy1d.py --ds cos --activn swish --loss laplace --kl_scale 1. --diff_const 0.2 --driftw_scale 0.1 --aug_dim 2 --stl --prior_dw ou
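The `--prior_dw ou` option appears to select an Ornstein-Uhlenbeck (OU) prior over the weight trajectories. As a hedged sketch of what such a prior looks like, the hypothetical function below simulates dw = -theta * w dt + sigma dB with Euler-Maruyama, starting from w = 0. The values of theta and sigma are illustrative and are not taken from the script's flags. Over long horizons the marginal variance of each weight approaches sigma^2 / (2 * theta).

```python
import jax
import jax.numpy as jnp


def sample_ou_prior(key, num_weights, sigma=0.2, theta=1.0, t1=5.0, nsteps=500):
    """Sample weight paths from dw = -theta * w dt + sigma dB (hypothetical helper).

    Returns an array of shape (nsteps + 1, num_weights) on a uniform time grid,
    with all weights starting at zero.
    """
    dt = t1 / nsteps
    dbs = jnp.sqrt(dt) * jax.random.normal(key, (nsteps, num_weights))

    def step(w, db):
        # Mean reversion toward 0 plus Brownian noise.
        w_next = w - theta * w * dt + sigma * db
        return w_next, w_next

    w0 = jnp.zeros(num_weights)
    _, path = jax.lax.scan(step, w0, dbs)
    return jnp.concatenate([w0[None], path], axis=0)


paths = sample_ou_prior(jax.random.PRNGKey(0), 10000)
```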

Image Classification:

To train an SDEBNN model:

python examples/jax/sdebnn_classification.py --output <output directory> --model sdenet --aug 2 --nblocks 2-2-2 --diff_coef 0.2 --fx_dim 64 --fw_dims 2-64-2 --nsteps 20 --nsamples 1
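The `--aug 2` flag augments the inputs. One common way to do this for continuous-depth models, used in Augmented Neural ODEs, is to append zero-valued channels so the dynamics have extra dimensions to work in. The sketch below assumes that scheme and an NHWC layout (the default for stax convolutions); the helper name is hypothetical and the script may augment differently.

```python
import jax.numpy as jnp


def augment_channels(x, aug):
    """Append `aug` zero-valued channels to an NHWC batch (hypothetical helper)."""
    # Pad only the trailing (channel) axis, and only on the high side.
    pad_width = [(0, 0)] * (x.ndim - 1) + [(0, aug)]
    return jnp.pad(x, pad_width)


images = jnp.ones((2, 32, 32, 3))        # e.g. a CIFAR-10-sized batch
augmented = augment_channels(images, 2)  # mirrors --aug 2
```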

To train a ResNet baseline, specify --model resnet; for a Bayesian ResNet baseline, specify --meanfield_sdebnn.
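The flag name `--meanfield_sdebnn` suggests a mean-field variational baseline, where each weight has an independent Gaussian posterior. A minimal sketch of such a layer follows; the helper names are hypothetical and the zero-mean Gaussian prior is chosen for illustration. It samples weights with the reparameterization trick and computes the closed-form KL term that enters the ELBO.

```python
import jax
import jax.numpy as jnp


def init_meanfield_dense(key, in_dim, out_dim, init_log_std=-5.0):
    # Per-weight Gaussian posterior: a mean and a log standard deviation.
    w_mean = jax.random.normal(key, (in_dim, out_dim)) / jnp.sqrt(in_dim)
    w_log_std = jnp.full((in_dim, out_dim), init_log_std)
    return w_mean, w_log_std


def meanfield_dense(params, x, key):
    """Sample w ~ N(mean, std^2) independently per entry, then apply x @ w."""
    w_mean, w_log_std = params
    eps = jax.random.normal(key, w_mean.shape)
    w = w_mean + jnp.exp(w_log_std) * eps
    return x @ w


def kl_to_prior(params, prior_std=1.0):
    """Closed-form KL(q || N(0, prior_std^2)), summed over all weights."""
    w_mean, w_log_std = params
    var_ratio = jnp.exp(2 * w_log_std) / prior_std ** 2
    return 0.5 * jnp.sum(var_ratio + w_mean ** 2 / prior_std ** 2
                         - 1.0 - jnp.log(var_ratio))


params = init_meanfield_dense(jax.random.PRNGKey(0), 3, 2)
out = meanfield_dense(params, jnp.ones((5, 3)), jax.random.PRNGKey(1))
```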

TorchBNN: SDE-BNN in PyTorch

A PyTorch implementation of the Brax framework powered by the torchsde backend.

Usage

All examples can be run with different vision datasets and include TensorBoard logging for key metrics.

Toy 1D regression to learn multi-modal posterior:

python examples/torch/sdebnn_toy1d.py --output_dir <dst_path>

Arbitrarily expressive approximate posteriors from learning non-Gaussian marginals.

Image Classification:

All hyperparameters can be found in the training script. Train with --adjoint for memory-efficient backpropagation and with --adaptive for adaptive computation (if training with both adjoint and adaptive modes, also set --adjoint_adaptive True).

python examples/torch/sdebnn_classification.py --train-dir <output directory> --data cifar10 --dt 0.05 --method midpoint --adjoint True --adaptive True --adjoint_adaptive True --inhomogeneous True
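The command above selects `--method midpoint`, which in torchsde is a solver for Stratonovich SDEs. For intuition only, here is a fixed-step JAX sketch of a midpoint step; it is a hypothetical helper, not torchsde's implementation, and it implements neither the adaptive nor the adjoint mode. Each step takes a half step to the interval midpoint and then re-evaluates drift and diffusion there, reusing the same Brownian increment.

```python
import jax
import jax.numpy as jnp


def midpoint_stratonovich(f, g, y0, t0, t1, nsteps, key):
    """Fixed-grid midpoint scheme for the Stratonovich SDE dy = f dt + g o dW.

    Hypothetical helper. Assumes diagonal noise (g(t, y) has the shape of y).
    Returns the terminal state and the Brownian increments that were used.
    """
    dt = (t1 - t0) / nsteps
    ts = t0 + dt * jnp.arange(nsteps)
    dws = jnp.sqrt(dt) * jax.random.normal(key, (nsteps,) + jnp.shape(y0))

    def step(y, inputs):
        t, dw = inputs
        # Half step to the midpoint using the full-interval increment dw.
        y_mid = y + 0.5 * f(t, y) * dt + 0.5 * g(t, y) * dw
        t_mid = t + 0.5 * dt
        # Full step with drift and diffusion evaluated at the midpoint.
        y_next = y + f(t_mid, y_mid) * dt + g(t_mid, y_mid) * dw
        return y_next, None

    y1, _ = jax.lax.scan(step, y0, (ts, dws))
    return y1, dws


# Usage: Stratonovich GBM dy = mu*y dt + sigma*y o dW on 8 paths.
mu, sigma = 0.5, 0.5
y1, dws = midpoint_stratonovich(lambda t, y: mu * y, lambda t, y: sigma * y,
                                jnp.ones(8), 0.0, 1.0, 1000,
                                jax.random.PRNGKey(0))
# Exact Stratonovich solution at t = 1 (no -sigma^2/2 Ito correction).
exact = jnp.exp(mu + sigma * jnp.sum(dws, axis=0))
```

For scalar or diagonal noise this scheme matches the Stratonovich Milstein expansion to leading order, so it tracks the exact path closely on a fine grid.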

References

[1] Winnie Xu, Ricky T. Q. Chen, Xuechen Li, David Duvenaud. "Infinitely Deep Bayesian Neural Networks with Stochastic Differential Equations." Preprint 2021. [arxiv]

[2] Xuechen Li, Ting-Kam Leonard Wong, Ricky T. Q. Chen, David Duvenaud. "Scalable Gradients for Stochastic Differential Equations." AISTATS 2020. [arxiv]

[3] Ricky T. Q. Chen, Yulia Rubanova, Jesse Bettencourt, David Duvenaud. "Neural Ordinary Differential Equations." NeurIPS 2018. [arxiv]


If you found this library useful in your research, please consider citing:

@article{xu2021sdebnn,
  title={Infinitely Deep Bayesian Neural Networks with Stochastic Differential Equations},
  author={Xu, Winnie and Chen, Ricky T. Q. and Li, Xuechen and Duvenaud, David},
  archivePrefix = {arXiv},
  year={2021}
}