tree-math: mathematical operations for JAX pytrees

Overview

tree-math: mathematical operations for JAX pytrees

tree-math makes it easy to implement numerical algorithms that work on JAX pytrees, such as iterative methods for optimization and equation solving. It does so by providing a wrapper class tree_math.Vector that defines array operations such as infix arithmetic and dot-products on pytrees as if they were vectors.

Why tree-math

In a library like SciPy, numerical algorithms are typically written to handle fixed-rank arrays, e.g., scipy.integrate.solve_ivp requires inputs of shape (n,). This is convenient for implementors of numerical methods, but not for users, because 1d arrays are typically not the best way to keep track of state for non-trivial functions (e.g., neural networks or PDE solvers).

tree-math provides an alternative to flattening and unflattening these more complex data structures ("pytrees") for use in numerical algorithms. Instead, the numerical algorithm itself can be written in way to handle arbitrary collections of arrays stored in pytrees. This avoids unnecessary memory copies, and gives the user more control over the memory layouts used in computation. In practice, this can often makes a big difference for computational efficiency as well, which is why support for flexible data structures is so prevalent inside libraries that use JAX.

Installation

tree-math is implemented in pure Python, and only depends upon JAX.

You can install it from PyPI: pip install tree-math.

User guide

tree-math is simple to use. Just pass arbitrary pytree objects into tree_math.Vector to create an a object that arithmetic as if all leaves of the pytree were flattened and concatenated together:

>>> import tree_math as tm
>>> import jax.numpy as jnp
>>> v = tm.Vector({'x': 1, 'y': jnp.arange(2, 4)})
>>> v
tree_math.Vector({'x': 1, 'y': DeviceArray([2, 3], dtype=int32)})
>>> v + 1
tree_math.Vector({'x': 2, 'y': DeviceArray([3, 4], dtype=int32)})
>>> v.sum()
DeviceArray(6, dtype=int32)

You can also find a few functions defined on vectors in tree_math.numpy, which implements a very restricted subset of jax.numpy. If you're interested in more functionality, please open an issue to discuss before sending a pull request. (In the long term, this separate module might disappear if we can support Vector objects directly inside jax.numpy.)

Vector objects are pytrees themselves, which means the are compatible with JAX transformations like jit, vmap and grad, and control flow like while_loop and cond.

When you're done manipulating vectors, you can pull out the underlying pytrees from the .tree property:

>>> v.tree
{'x': 1, 'y': DeviceArray([2, 3], dtype=int32)}

As an alternative to manipulating Vector objects directly, you can also use the functional transformations wrap and unwrap (see the "Example usage" below).

One important difference between tree_math and jax.numpy is that dot products in tree_math default to full precision on all platforms, rather than defaulting to bfloat16 precision on TPUs. This is useful for writing most numerical algorithms, and will likely be JAX's default behavior in the future.

In the near-term, we also plan to add a Matrix class that will make it possible to use tree-math for numerical algorithms such as L-BFGS which use matrices to represent stacks of vectors.

Example usage

Here is how we could write the preconditioned conjugate gradient method. Notice how similar the implementation is to the pseudocode from Wikipedia, unlike the implementation in JAX:

atol2) & (k < maxiter) def body_fun(value): x, r, gamma, p, k = value Ap = A(p) alpha = gamma / (p.conj() @ Ap) x_ = x + alpha * p r_ = r - alpha * Ap z_ = M(r_) gamma_ = r_.conj() @ z_ beta_ = gamma_ / gamma p_ = z_ + beta_ * p return x_, r_, gamma_, p_, k + 1 r0 = b - A(x0) p0 = z0 = M(r0) gamma0 = r0 @ z0 initial_value = (x0, r0, gamma0, p0, 0) x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value) return x_final">
import functools
from jax import lax
import tree_math as tm
import tree_math.numpy as tnp

@functools.partial(tm.wrap, vector_argnames=['b', 'x0'])
def cg(A, b, x0, M=lambda x: x, maxiter=5, tol=1e-5, atol=0.0):
  """jax.scipy.sparse.linalg.cg, written with tree_math."""
  A = tm.unwrap(A)
  M = tm.unwrap(M)

  atol2 = tnp.maximum(tol**2 * (b @ b), atol**2)

  def cond_fun(value):
    x, r, gamma, p, k = value
    return (r @ r > atol2) & (k < maxiter)

  def body_fun(value):
    x, r, gamma, p, k = value
    Ap = A(p)
    alpha = gamma / (p.conj() @ Ap)
    x_ = x + alpha * p
    r_ = r - alpha * Ap
    z_ = M(r_)
    gamma_ = r_.conj() @ z_
    beta_ = gamma_ / gamma
    p_ = z_ + beta_ * p
    return x_, r_, gamma_, p_, k + 1

  r0 = b - A(x0)
  p0 = z0 = M(r0)
  gamma0 = r0 @ z0
  initial_value = (x0, r0, gamma0, p0, 0)

  x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value)
  return x_final
Owner
Google
Google ❤️ Open Source
Google
The official implementation of EIGNN: Efficient Infinite-Depth Graph Neural Networks (NeurIPS 2021)

EIGNN: Efficient Infinite-Depth Graph Neural Networks The official implementation of EIGNN: Efficient Infinite-Depth Graph Neural Networks (NeurIPS 20

Juncheng Liu 14 Nov 22, 2022
Qimera: Data-free Quantization with Synthetic Boundary Supporting Samples

Qimera: Data-free Quantization with Synthetic Boundary Supporting Samples This repository is the official implementation of paper [Qimera: Data-free Q

Kanghyun Choi 21 Nov 03, 2022
Code for CVPR2021 "Visualizing Adapted Knowledge in Domain Transfer". Visualization for domain adaptation. #explainable-ai

Visualizing Adapted Knowledge in Domain Transfer @inproceedings{hou2021visualizing, title={Visualizing Adapted Knowledge in Domain Transfer}, auth

Yunzhong Hou 80 Dec 25, 2022
Text to image synthesis using thought vectors

Text To Image Synthesis Using Thought Vectors This is an experimental tensorflow implementation of synthesizing images from captions using Skip Though

Paarth Neekhara 2.1k Jan 05, 2023
Code for BMVC2021 "MOS: A Low Latency and Lightweight Framework for Face Detection, Landmark Localization, and Head Pose Estimation"

MOS-Multi-Task-Face-Detect Introduction This repo is the official implementation of "MOS: A Low Latency and Lightweight Framework for Face Detection,

104 Dec 08, 2022
Stereo Hybrid Event-Frame (SHEF) Cameras for 3D Perception, IROS 2021

For academic use only. Stereo Hybrid Event-Frame (SHEF) Cameras for 3D Perception Ziwei Wang, Liyuan Pan, Yonhon Ng, Zheyu Zhuang and Robert Mahony Th

Ziwei Wang 11 Jan 04, 2023
TVNet: Temporal Voting Network for Action Localization

TVNet: Temporal Voting Network for Action Localization This repo holds the codes of paper: "TVNet: Temporal Voting Network for Action Localization". P

hywang 5 Jul 26, 2022
Pytorch Implementation of Adversarial Deep Network Embedding for Cross-Network Node Classification

Pytorch Implementation of Adversarial Deep Network Embedding for Cross-Network Node Classification (ACDNE) This is a pytorch implementation of the Adv

陈志豪 8 Oct 13, 2022
Source code for the paper: Variance-Aware Machine Translation Test Sets (NeurIPS 2021 Datasets and Benchmarks Track)

Variance-Aware-MT-Test-Sets Variance-Aware Machine Translation Test Sets License See LICENSE. We follow the data licensing plan as the same as the WMT

NLP2CT Lab, University of Macau 5 Dec 21, 2021
Unoffical reMarkable AddOn for Firefox.

reMarkable for Firefox (Download) This repo converts the offical reMarkable Chrome Extension into a Firefox AddOn published here under the name "Unoff

Jelle Schutter 45 Nov 28, 2022
Autoformer: Decomposition Transformers with Auto-Correlation for Long-Term Series Forecasting

Autoformer (NeurIPS 2021) Autoformer: Decomposition Transformers with Auto-Correlation for Long-Term Series Forecasting Time series forecasting is a c

THUML @ Tsinghua University 847 Jan 08, 2023
Source code of the paper PatchGraph: In-hand tactile tracking with learned surface normals.

PatchGraph This repository contains the source code of the paper PatchGraph: In-hand tactile tracking with learned surface normals. Installation Creat

Paloma Sodhi 11 Dec 15, 2022
Direct LiDAR Odometry: Fast Localization with Dense Point Clouds

Direct LiDAR Odometry: Fast Localization with Dense Point Clouds DLO is a lightweight and computationally-efficient frontend LiDAR odometry solution w

VECTR at UCLA 369 Dec 30, 2022
a curated list of docker-compose files prepared for testing data engineering tools, databases and open source libraries.

data-services A repository for storing various Data Engineering docker-compose files in one place. How to use it ? Set the required settings in .env f

BigData.IR 525 Dec 03, 2022
Deep Occlusion-Aware Instance Segmentation with Overlapping BiLayers [CVPR 2021]

Deep Occlusion-Aware Instance Segmentation with Overlapping BiLayers [BCNet, CVPR 2021] This is the official pytorch implementation of BCNet built on

Lei Ke 434 Dec 01, 2022
Face Mesh is a face geometry solution that estimates 468 3D face landmarks in real-time even on mobile devices

Face-Mesh Face Mesh is a face geometry solution that estimates 468 3D face landmarks in real-time even on mobile devices. It employs machine learning

Farnam Javadi 9 Dec 21, 2022
This Repostory contains the pretrained DTLN-aec model for real-time acoustic echo cancellation.

This Repostory contains the pretrained DTLN-aec model for real-time acoustic echo cancellation.

Nils L. Westhausen 182 Jan 07, 2023
CVPR2020 Counterfactual Samples Synthesizing for Robust VQA

CVPR2020 Counterfactual Samples Synthesizing for Robust VQA This repo contains code for our paper "Counterfactual Samples Synthesizing for Robust Visu

72 Dec 22, 2022
A small demonstration of using WebDataset with ImageNet and PyTorch Lightning

A small demonstration of using WebDataset with ImageNet and PyTorch Lightning

Tom 50 Dec 16, 2022
[ACM MM 2021] Yes, "Attention is All You Need", for Exemplar based Colorization

Transformer for Image Colorization This is an implemention for Yes, "Attention Is All You Need", for Exemplar based Colorization, and the current soft

Wang Yin 30 Dec 07, 2022