A Python library for differentiable optimal control on accelerators.

Related tags

Deep Learningtrajax
Overview

trajax

A Python library for differentiable optimal control on accelerators.

Trajax builds on JAX and hence code written with Trajax supports JAX's transformations. In particular, Trajax's solvers:

  1. Are automatically efficiently differentiable, via jax.grad.
  2. Scale up to parallel instances via jax.vmap and jax.pmap.
  3. Can run on CPUs, GPUs, and TPUs without code changes, and support end-to-end compilation with jax.jit.
  4. Are made available from Python, written with NumPy.

In Trajax, differentiation through the solution of a trajectory optimization problem is done more efficiently than by differentiating the solver implementation directly. Specifically, Trajax defines custom differentiation routines for its solvers. It registers these with JAX so that they are picked up whenever using JAX's autodiff features (e.g. jax.grad) to differentiate functions that call a Trajax solver.

This is a research project, not an official Google product.

Trajax is currently a work in progress, maintained by a few individuals at Google Research. While we are actively using Trajax in our own research projects, expect there to be bugs and rough edges compared to commercially available solvers.

Trajectory optimization and optimal control

We consider classical optimal control tasks concerning optimizing trajectories of a given discrete time dynamical system by solving the following problem. Given a cost function c, dynamics function f, and initial state x0, the goal is to compute:

argmin(lambda X, U: sum(c(X[t], U[t], t) for t in range(T)) + c_final(X[T]))

subject to the constraint that X[0] == x0 and that:

all(X[t + 1] == f(X[t], U[t], t) for t in range(T))

There are many resources for more on trajectory optimization, including Dynamic Programming and Optimal Control by Dimitri Bertsekas and Underactuated Robotics by Russ Tedrake.

API

In describing the API, it will be useful to abbreviate a JAX/NumPy floating point ndarray of shape (a, b, …) as a type denoted F[a, b, …]. Assume n is the state dimension, d is the control dimension, and T is the time horizon.

Problem setup convention/signature

Setting up a problem requires writing two functions, cost and dynamics, with type signatures:

cost(state: F[n], action: F[d], time_step: int) : float
dynamics(state: F[n], action: F[d], time_step: int) : F[n]

Note that even if a dimension n or d is 1, the corresponding state or action representation is still a rank-1 ndarray (i.e. a vector, of length 1).

Because Trajax uses JAX, the cost and dynamics functions must be written in a functional programming style as required by JAX. See the JAX readme for details on writing JAX-friendly functional code. By and large, functions that have no side effects and that use jax.numpy in place of numpy are likely to work.

Solvers

If we abbreviate the type of the above two functions as CostFn and DynamicsFn, then our solvers have the following type signature prefix in common:

solver(cost: CostFn, dynamics: DynamicsFn, initial_state: F[n], initial_actions: F[T, d], *solver_args, **solver_kwargs): SolverOutput

SolverOutput is a tuple of (F[T + 1, n], F[T, d], float, *solver_outputs). The first three tuple components represent the optimal state trajectory, optimal control sequence, and the optimal objective value achieved, respectively. The remaining *solver_outputs are specific to the particular solver (such as number of iterations, norm of the final gradient, etc.).

There are currently four solvers provided: ilqr, scipy_minimize, cem, and random_shooting. Each extends the signatures above with solver-specific arguments and output values. Details are provided in each solver function's docstring.

Underlying the ilqr implementation is a time-varying LQR routine, which solves a special case of the above problem, where costs are convex quadratic and dynamics are affine. To capture this, both are represented as matrices. This routine is also made available as tvlqr.

Objectives

One might want to write a custom solver, or work with an objective function for any other reason. To that end, Trajax offers the optimal control objective in the form of an API function:

objective(cost: CostFn, dynamics: DynamicsFn, initial_state: F[n], actions: F[T, d]): float

Combining this function with JAX's autodiff capabilities offers, for example, a starting point for writing a first-order custom solver. For example:

def improve_controls(cost, dynamics, U, x0, eta, num_iters):
  grad_fn = jax.grad(trajax.objective, argnums=(2,))
  for i in range(num_iters):
    U = U - eta * grad_fn(cost, dynamics, U, x0)
  return U

The solvers provided by Trajax are actually built around this objective function. For instance, the scipy_minimize solver simply calls scipy.minimize.minimize with the gradient and Hessian-vector product functions derived from objective using jax.grad and jax.hessian.

Limitations

​​Just as Trajax inherits the autodiff, compilation, and parallelism features of JAX, it also inherits its corresponding limitations. Functions such as the cost and dynamics given to a solver must be written using jax.numpy in place of standard numpy, and must conform to a functional style; see the JAX readme. Due to the complexity of trajectory optimizer implementations, initial compilation times can be long.

Owner
Google
Google ❤️ Open Source
Google
The mini-MusicNet dataset

mini-MusicNet A music-domain dataset for multi-label classification Music transcription is sequence-to-sequence prediction problem: given an audio per

John Thickstun 4 Nov 09, 2022
So-ViT: Mind Visual Tokens for Vision Transformer

So-ViT: Mind Visual Tokens for Vision Transformer        Introduction This repository contains the source code under PyTorch framework and models trai

Jiangtao Xie 44 Nov 24, 2022
Normalization Calibration (NorCal) for Long-Tailed Object Detection and Instance Segmentation

NorCal Normalization Calibration (NorCal) for Long-Tailed Object Detection and Instance Segmentation On Model Calibration for Long-Tailed Object Detec

Tai-Yu (Daniel) Pan 24 Dec 25, 2022
Implementation of Analyzing and Improving the Image Quality of StyleGAN (StyleGAN 2) in PyTorch

Implementation of Analyzing and Improving the Image Quality of StyleGAN (StyleGAN 2) in PyTorch

Kim Seonghyeon 2.2k Jan 01, 2023
PyTorch - Python + Nim

Master Release Pytorch - Py + Nim A Nim frontend for pytorch, aiming to be mostly auto-generated and internally using ATen. Because Nim compiles to C+

Giovanni Petrantoni 425 Dec 22, 2022
Toontown: Galaxy, a new Toontown game based on Disney's Toontown Online

Toontown: Galaxy The official archive repo for Toontown: Galaxy, a new Toontown

1 Feb 15, 2022
Scale-aware Automatic Augmentation for Object Detection (CVPR 2021)

SA-AutoAug Scale-aware Automatic Augmentation for Object Detection Yukang Chen, Yanwei Li, Tao Kong, Lu Qi, Ruihang Chu, Lei Li, Jiaya Jia [Paper] [Bi

DV Lab 182 Dec 29, 2022
Library for implementing reservoir computing models (echo state networks) for multivariate time series classification and clustering.

Framework overview This library allows to quickly implement different architectures based on Reservoir Computing (the family of approaches popularized

Filippo Bianchi 249 Dec 21, 2022
An open source implementation of CLIP.

OpenCLIP Welcome to an open source implementation of OpenAI's CLIP (Contrastive Language-Image Pre-training). The goal of this repository is to enable

2.7k Dec 31, 2022
Pytorch implementation of the unsupervised object discovery method LOST.

LOST Pytorch implementation of the unsupervised object discovery method LOST. More details can be found in the paper: Localizing Objects with Self-Sup

Valeo.ai 189 Dec 25, 2022
Fast Neural Representations for Direct Volume Rendering

Fast Neural Representations for Direct Volume Rendering Sebastian Weiss, Philipp Hermüller, Rüdiger Westermann This repository contains the code and s

Sebastian Weiss 20 Dec 03, 2022
A vision library for performing sliced inference on large images/small objects

SAHI: Slicing Aided Hyper Inference A vision library for performing sliced inference on large images/small objects Overview Object detection and insta

Open Business Software Solutions 2.3k Jan 04, 2023
This repository includes the code of the sequence-to-sequence model for discontinuous constituent parsing described in paper Discontinuous Grammar as a Foreign Language.

Discontinuous Grammar as a Foreign Language This repository includes the code of the sequence-to-sequence model for discontinuous constituent parsing

Daniel Fernández-González 2 Apr 07, 2022
Multimodal Temporal Context Network (MTCN)

Multimodal Temporal Context Network (MTCN) This repository implements the model proposed in the paper: Evangelos Kazakos, Jaesung Huh, Arsha Nagrani,

Evangelos Kazakos 13 Nov 24, 2022
Official Pytorch Implementation of Length-Adaptive Transformer (ACL 2021)

Length-Adaptive Transformer This is the official Pytorch implementation of Length-Adaptive Transformer. For detailed information about the method, ple

Clova AI Research 93 Dec 28, 2022
Tools for robust generative diffeomorphic slice to volume reconstruction

RGDSVR Tools for Robust Generative Diffeomorphic Slice to Volume Reconstructions (RGDSVR) This repository provides tools to implement the methods in t

Lucilio Cordero-Grande 0 Oct 29, 2021
Attention over nodes in Graph Neural Networks using PyTorch (NeurIPS 2019)

Intro This repository contains code to generate data and reproduce experiments from our NeurIPS 2019 paper: Boris Knyazev, Graham W. Taylor, Mohamed R

Boris Knyazev 242 Jan 06, 2023
Punctuation Restoration using Transformer Models for High-and Low-Resource Languages

Punctuation Restoration using Transformer Models This repository contins official implementation of the paper Punctuation Restoration using Transforme

Tanvirul Alam 142 Jan 01, 2023
This repository is an official implementation of the paper MOTR: End-to-End Multiple-Object Tracking with TRansformer.

MOTR: End-to-End Multiple-Object Tracking with TRansformer This repository is an official implementation of the paper MOTR: End-to-End Multiple-Object

348 Jan 07, 2023
[NeurIPS-2021] Slow Learning and Fast Inference: Efficient Graph Similarity Computation via Knowledge Distillation

Efficient Graph Similarity Computation - (EGSC) This repo contains the source code and dataset for our paper: Slow Learning and Fast Inference: Effici

23 Nov 11, 2022