Differentiable ODE solvers with full GPU support and O(1)-memory backpropagation.

Overview

PyTorch Implementation of Differentiable ODE Solvers

This library provides ordinary differential equation (ODE) solvers implemented in PyTorch. Backpropagation through ODE solutions is supported using the adjoint method for constant memory cost. For usage of ODE solvers in deep learning applications, see reference [1].

Because the solvers are implemented in PyTorch, all algorithms in this repository can run on the GPU.

Installation

To install the latest stable version:

pip install torchdiffeq

To install the latest version from GitHub:

pip install git+https://github.com/rtqichen/torchdiffeq

Examples

Examples are placed in the examples directory.

If you are interested in using this library, start with examples/ode_demo.py, which shows how to use torchdiffeq to fit a simple spiral ODE.

Basic usage

This library provides one main interface, odeint, which contains general-purpose algorithms for solving initial value problems (IVPs), with gradients implemented for all main arguments. An initial value problem consists of an ODE and an initial value,

dy/dt = f(t, y)    y(t_0) = y_0.

The goal of an ODE solver is to find a continuous trajectory satisfying the ODE that passes through the initial condition.

To solve an IVP using the default solver:

from torchdiffeq import odeint

odeint(func, y0, t)

where func is any callable implementing the ordinary differential equation f(t, y), y0 is an any-D Tensor representing the initial values, and t is a 1-D Tensor containing the evaluation points. The initial time is taken to be t[0].

Backpropagation through odeint goes through the internals of the solver. This is not numerically stable for all solvers, although it should be fine with the default dopri5 method. We instead encourage using the adjoint method explained in [1], which uses O(1) memory and therefore allows solving with as many steps as necessary.

To use the adjoint method:

from torchdiffeq import odeint_adjoint as odeint

odeint(func, y0, t)

odeint_adjoint simply wraps around odeint, but will use only O(1) memory in exchange for solving an adjoint ODE in the backward call.

The biggest gotcha is that func must be an nn.Module when using the adjoint method. The module is used to collect the parameters of the differential equation.

Differentiable event handling

We allow terminating an ODE solution based on an event function. Backpropagation through most solvers is supported. For usage of event handling in deep learning applications, see reference [2].

This can be invoked with odeint_event:

from torchdiffeq import odeint_event
odeint_event(func, y0, t0, *, event_fn, reverse_time=False, odeint_interface=odeint, **kwargs)
  • func and y0 are the same as in odeint.
  • t0 is a scalar representing the initial time value.
  • event_fn(t, y) returns a tensor, and is a required keyword argument.
  • reverse_time is a boolean specifying whether we should solve in reverse time. Default is False.
  • odeint_interface is one of odeint or odeint_adjoint, specifying whether adjoint mode should be used for differentiating through the ODE solution. Default is odeint.
  • **kwargs: any remaining keyword arguments are passed to odeint_interface.

The solve is terminated at an event time t and state y when an element of event_fn(t, y) is equal to zero. Multiple outputs from event_fn can be used to specify multiple event functions, of which the first to trigger will terminate the solve.

Both the event time and final state are returned from odeint_event, and can be differentiated. Gradients will be backpropagated through the event function.

The numerical precision for the event time is determined by the atol argument.

See examples/bouncing_ball.py for an example of simulating and differentiating through a bouncing ball.

Keyword arguments for odeint(_adjoint)

Keyword arguments:

  • rtol Relative tolerance.
  • atol Absolute tolerance.
  • method One of the solvers listed below.
  • options A dictionary of solver-specific options; see the further documentation.

List of ODE Solvers:

Adaptive-step:

  • dopri8 Runge-Kutta of order 8 of Dormand-Prince-Shampine.
  • dopri5 Runge-Kutta of order 5 of Dormand-Prince-Shampine [default].
  • bosh3 Runge-Kutta of order 3 of Bogacki-Shampine.
  • fehlberg2 Runge-Kutta-Fehlberg of order 2.
  • adaptive_heun Runge-Kutta of order 2.

Fixed-step:

  • euler Euler method.
  • midpoint Midpoint method.
  • rk4 Fourth-order Runge-Kutta with 3/8 rule.
  • explicit_adams Explicit Adams-Bashforth.
  • implicit_adams Implicit Adams-Bashforth-Moulton.
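
To make the fixed-step idea concrete, here is a pure-Python sketch of a fourth-order Runge-Kutta step using the 3/8 rule, applied with equal steps to a scalar ODE. It is an illustration of the scheme only, not the library's implementation, and the helper names rk4_38_step and integrate are hypothetical.

```python
import math


def rk4_38_step(f, t, y, h):
    """Advance y by one step of size h using the 3/8-rule Runge-Kutta scheme."""
    k1 = f(t, y)
    k2 = f(t + h / 3, y + h * k1 / 3)
    k3 = f(t + 2 * h / 3, y + h * (k2 - k1 / 3))
    k4 = f(t + h, y + h * (k1 - k2 + k3))
    return y + h * (k1 + 3 * k2 + 3 * k3 + k4) / 8


def integrate(f, y0, t0, t1, num_steps):
    """Integrate from t0 to t1 with num_steps equal fixed steps."""
    h = (t1 - t0) / num_steps
    t, y = t0, y0
    for _ in range(num_steps):
        y = rk4_38_step(f, t, y, h)
        t += h
    return y


# Illustrative problem: dy/dt = -y, y(0) = 1, so y(1) = exp(-1).
y_end = integrate(lambda t, y: -y, 1.0, 0.0, 1.0, num_steps=100)
error = abs(y_end - math.exp(-1.0))
```

Halving the step size in a fourth-order scheme like this one reduces the error by roughly a factor of 16, which is the speed and accuracy trade-off that a fixed step size controls.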

Additionally, all solvers available through SciPy are wrapped for use with scipy_solver.

For most problems, good choices are the default dopri5, or rk4 with options=dict(step_size=...) set appropriately small. Adjusting the tolerances (adaptive solvers) or the step size (fixed solvers) allows trading off speed against accuracy.

Frequently Asked Questions

Take a look at our FAQ for frequently asked questions.

Further documentation

For details of the adjoint-specific and solver-specific options, check out the further documentation.

References

Applications of differentiable ODE solvers and event handling are discussed in these two papers:

[1] Ricky T. Q. Chen, Yulia Rubanova, Jesse Bettencourt, David Duvenaud. "Neural Ordinary Differential Equations." Advances in Neural Information Processing Systems. 2018. [arxiv]

[2] Ricky T. Q. Chen, Brandon Amos, Maximilian Nickel. "Learning Neural Event Functions for Ordinary Differential Equations." International Conference on Learning Representations. 2021. [arxiv]


If you found this library useful in your research, please consider citing.

@article{chen2018neuralode,
  title={Neural Ordinary Differential Equations},
  author={Chen, Ricky T. Q. and Rubanova, Yulia and Bettencourt, Jesse and Duvenaud, David},
  journal={Advances in Neural Information Processing Systems},
  year={2018}
}

@article{chen2021eventfn,
  title={Learning Neural Event Functions for Ordinary Differential Equations},
  author={Chen, Ricky T. Q. and Amos, Brandon and Nickel, Maximilian},
  journal={International Conference on Learning Representations},
  year={2021}
}