tree-math: mathematical operations for JAX pytrees

Overview

tree-math makes it easy to implement numerical algorithms that work on JAX pytrees, such as iterative methods for optimization and equation solving. It does so by providing a wrapper class tree_math.Vector that defines array operations such as infix arithmetic and dot-products on pytrees as if they were vectors.

Why tree-math

In a library like SciPy, numerical algorithms are typically written to handle fixed-rank arrays, e.g., scipy.integrate.solve_ivp requires inputs of shape (n,). This is convenient for implementors of numerical methods, but not for users, because 1d arrays are typically not the best way to keep track of state for non-trivial functions (e.g., neural networks or PDE solvers).

tree-math provides an alternative to flattening and unflattening these more complex data structures ("pytrees") for use in numerical algorithms. Instead, the numerical algorithm itself can be written in a way that handles arbitrary collections of arrays stored in pytrees. This avoids unnecessary memory copies and gives the user more control over the memory layouts used in computation. In practice, this can often make a big difference for computational efficiency as well, which is why support for flexible data structures is so prevalent inside libraries that use JAX.
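To make the contrast concrete, here is a minimal sketch using only JAX utilities (jax.flatten_util.ravel_pytree and jax.tree_util.tree_map), not tree-math itself. The params dictionary and its shapes are made up for illustration. The first approach concatenates every leaf into one 1d array and splits it back afterwards; the second applies the same operation leaf by leaf and never builds the concatenated array.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# A hypothetical pytree of model state (names and shapes are illustrative).
params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}

# Approach 1: flatten to a single 1d array, compute, then unflatten.
# ravel_pytree concatenates all leaves, which copies the data.
flat, unravel = ravel_pytree(params)
scaled_flat = unravel(2.0 * flat)

# Approach 2: apply the same operation to each leaf in place of flattening.
scaled_tree = jax.tree_util.tree_map(lambda leaf: 2.0 * leaf, params)
```

Both approaches give the same result here; the difference is that the second keeps each array in its original shape and layout throughout.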

Installation

tree-math is implemented in pure Python, and only depends upon JAX.

You can install it from PyPI: pip install tree-math.

User guide

tree-math is simple to use. Just pass arbitrary pytree objects into tree_math.Vector to create an object that supports arithmetic as if all leaves of the pytree were flattened and concatenated together:

>>> import tree_math as tm
>>> import jax.numpy as jnp
>>> v = tm.Vector({'x': 1, 'y': jnp.arange(2, 4)})
>>> v
tree_math.Vector({'x': 1, 'y': DeviceArray([2, 3], dtype=int32)})
>>> v + 1
tree_math.Vector({'x': 2, 'y': DeviceArray([3, 4], dtype=int32)})
>>> v.sum()
DeviceArray(6, dtype=int32)

You can also find a few functions defined on vectors in tree_math.numpy, which implements a very restricted subset of jax.numpy. If you're interested in more functionality, please open an issue to discuss before sending a pull request. (In the long term, this separate module might disappear if we can support Vector objects directly inside jax.numpy.)

Vector objects are pytrees themselves, which means they are compatible with JAX transformations like jit, vmap and grad, and control flow like while_loop and cond.

When you're done manipulating vectors, you can pull out the underlying pytrees from the .tree property:

>>> v.tree
{'x': 1, 'y': DeviceArray([2, 3], dtype=int32)}

As an alternative to manipulating Vector objects directly, you can also use the functional transformations wrap and unwrap (see "Example usage" below).

One important difference between tree_math and jax.numpy is that dot products in tree_math default to full precision on all platforms, rather than defaulting to bfloat16 precision on TPUs. This is useful for writing most numerical algorithms, and will likely be JAX's default behavior in the future.
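For comparison, this is a minimal sketch of how full precision can be requested explicitly in plain jax.numpy through the precision argument of jnp.dot. The input arrays are arbitrary. On CPU and GPU the two results typically agree; the setting mainly matters on TPUs, where the default may use reduced precision.

```python
import jax
import jax.numpy as jnp

# Arbitrary example inputs.
a = jnp.linspace(0.0, 1.0, 1024, dtype=jnp.float32)
b = jnp.linspace(1.0, 2.0, 1024, dtype=jnp.float32)

# Platform default precision (may be reduced on TPUs).
default_dot = jnp.dot(a, b)

# Explicitly request the highest available precision.
precise_dot = jnp.dot(a, b, precision=jax.lax.Precision.HIGHEST)
```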

In the near term, we also plan to add a Matrix class that will make it possible to use tree-math for numerical algorithms such as L-BFGS, which use matrices to represent stacks of vectors.

Example usage

Here is how we could write the preconditioned conjugate gradient method. Notice how similar the implementation is to the pseudocode from Wikipedia, unlike the implementation in JAX:

import functools
from jax import lax
import tree_math as tm
import tree_math.numpy as tnp

@functools.partial(tm.wrap, vector_argnames=['b', 'x0'])
def cg(A, b, x0, M=lambda x: x, maxiter=5, tol=1e-5, atol=0.0):
  """jax.scipy.sparse.linalg.cg, written with tree_math."""
  A = tm.unwrap(A)
  M = tm.unwrap(M)

  atol2 = tnp.maximum(tol**2 * (b @ b), atol**2)

  def cond_fun(value):
    x, r, gamma, p, k = value
    return (r @ r > atol2) & (k < maxiter)

  def body_fun(value):
    x, r, gamma, p, k = value
    Ap = A(p)
    alpha = gamma / (p.conj() @ Ap)
    x_ = x + alpha * p
    r_ = r - alpha * Ap
    z_ = M(r_)
    gamma_ = r_.conj() @ z_
    beta_ = gamma_ / gamma
    p_ = z_ + beta_ * p
    return x_, r_, gamma_, p_, k + 1

  r0 = b - A(x0)
  p0 = z0 = M(r0)
  gamma0 = r0 @ z0
  initial_value = (x0, r0, gamma0, p0, 0)

  x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value)
  return x_final
Owner
Google