A small library for creating and manipulating custom JAX Pytree classes

Overview

Treeo

  • Light-weight: has no dependencies other than jax.
  • Compatible: Treeo Tree objects are compatible with any jax function that accepts Pytrees.
  • Standards-based: treeo.field is built on top of Python's dataclasses.field.
  • Flexible: Treeo is compatible with both dataclass and non-dataclass classes.

Treeo lets you create class-based Pytrees so your custom objects can interact seamlessly with JAX. Uses of Treeo range from creating simple JAX-aware utility classes to serving as the core abstraction for full-blown frameworks. Treeo was originally extracted from the core of Treex and shares a lot in common with flax.struct.

Documentation | User Guide

Installation

Install using pip:

pip install treeo

Basics

With Treeo you can easily define your own custom Pytree classes by inheriting from Treeo's Tree class and using the field function to declare which fields are nodes (children) and which are static (metadata):

from dataclasses import dataclass

import jax
import jax.numpy as jnp
import treeo as to

@dataclass
class Person(to.Tree):
    height: jnp.ndarray = to.field(node=True) # I am a node field!
    name: str = to.field(node=False) # I am a static field!

field is just a wrapper around dataclasses.field, so you can define your Pytrees as dataclasses, but Treeo fully supports non-dataclass classes as well. Since all Tree instances are Pytrees, they work with the various functions from the jax library as expected:

p = Person(height=jnp.array(1.8), name="John")

# Trees can be jitted!
jax.jit(lambda person: person)(p) # Person(height=array(1.8), name='John')

# Trees can be mapped!
jax.tree_map(lambda x: 2 * x, p) # Person(height=array(3.6), name='John')
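
To make the node/static distinction concrete, here is a minimal sketch of the same idea written by hand with jax.tree_util. It is not Treeo's implementation; the class name ManualPerson is hypothetical and only illustrates the bookkeeping that a Tree class saves you from writing.

```python
from dataclasses import dataclass

import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
@dataclass
class ManualPerson:  # hypothetical name, for illustration only
    height: jnp.ndarray  # node: a child that JAX can trace and map over
    name: str            # static: metadata stored in the tree definition

    def tree_flatten(self):
        # Return (children, aux_data): children are nodes, aux_data is static.
        return (self.height,), self.name

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (height,) = children
        return cls(height=height, name=aux_data)


p = ManualPerson(height=jnp.array(1.8), name="John")

# Only the node field is visible to tree utilities; the static field is carried along.
doubled = tree_util.tree_map(lambda x: 2 * x, p)
leaves = tree_util.tree_leaves(p)
```

Here tree_map touches only height, while name passes through unchanged as part of the tree structure.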

Kinds

Treeo also includes a kind system that lets you give semantic meaning to fields (what a field represents within your application). A kind is just a type you pass to field via its kind argument:

class Parameter: pass
class BatchStat: pass

class BatchNorm(to.Tree):
    scale: jnp.ndarray = to.field(node=True, kind=Parameter)
    mean: jnp.ndarray = to.field(node=True, kind=BatchStat)

Kinds are very useful as a filtering mechanism via treeo.filter:

model = BatchNorm(...)

# select only Parameters, mean is filtered out
params = to.filter(model, Parameter) # BatchNorm(scale=array(...), mean=Nothing)

Nothing behaves like None in Python, but it is a special value that is used to represent the absence of a value within Treeo.

Treeo also offers the merge function, which lets you rejoin filtered Trees with logic similar to Python's dict.update, but applied recursively:

def loss_fn(params, model, ...):
    # add traced params to model
    model = to.merge(model, params)
    ...

# gradient only w.r.t. params
params = to.filter(model, Parameter) # BatchNorm(scale=array(...), mean=Nothing)
grads = jax.grad(loss_fn)(params, model, ...)
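
The filter-then-merge pattern can be illustrated without Treeo at all. The following is a rough pure-Python sketch on nested dicts, assuming a sentinel that plays the role of Nothing; the names NOTHING, filter_by_kind, and merge_trees are hypothetical and do not correspond to Treeo's internals.

```python
class _Nothing:
    """Hypothetical stand-in for a 'no value here' sentinel."""

    def __repr__(self):
        return "Nothing"


NOTHING = _Nothing()


def filter_by_kind(tree, kinds, wanted):
    """Keep leaves whose kind is `wanted`; replace all others with NOTHING."""
    out = {}
    for key, value in tree.items():
        if isinstance(value, dict):
            out[key] = filter_by_kind(value, kinds[key], wanted)
        else:
            out[key] = value if kinds[key] is wanted else NOTHING
    return out


def merge_trees(base, update):
    """Recursive dict.update-like merge: NOTHING in `update` keeps the base value."""
    out = {}
    for key, value in base.items():
        new = update.get(key, NOTHING)
        if isinstance(value, dict) and isinstance(new, dict):
            out[key] = merge_trees(value, new)
        else:
            out[key] = value if new is NOTHING else new
    return out


class Parameter: pass
class BatchStat: pass


model = {"norm": {"scale": 1.0, "mean": 0.5}}
kinds = {"norm": {"scale": Parameter, "mean": BatchStat}}

params = filter_by_kind(model, kinds, Parameter)  # mean is replaced by NOTHING
updated = {"norm": {"scale": 2.0, "mean": NOTHING}}
merged = merge_trees(model, updated)  # scale is updated, mean is kept
```

Because the filtered tree keeps the same shape as the original, merging it back only needs to skip the positions that hold the sentinel.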

For a more in-depth tour check out the User Guide.

Examples

A simple Tree

from dataclasses import dataclass

import jax
import jax.numpy as jnp
import treeo as to

@dataclass
class Character(to.Tree):
    position: jnp.ndarray = to.field(node=True)    # node field
    name: str = to.field(node=False, opaque=True)  # static field

character = Character(position=jnp.array([0, 0]), name='Adam')

# character can freely pass through jit
@jax.jit
def update(character: Character, velocity, dt) -> Character:
    character.position += velocity * dt
    return character

character = update(character, velocity=jnp.array([1.0, 0.2]), dt=0.1)

A Stateful Tree

from dataclasses import dataclass

import jax
import jax.numpy as jnp
import treeo as to

@dataclass
class Counter(to.Tree):
    n: jnp.array = to.field(default=jnp.array(0), node=True) # node
    step: int = to.field(default=1, node=False) # static

    def inc(self):
        self.n += self.step

counter = Counter(step=2) # Counter(n=jnp.array(0), step=2)

@jax.jit
def update(counter: Counter):
    counter.inc()
    return counter

counter = update(counter) # Counter(n=jnp.array(2), step=2)

# map over the tree

Full Example - Linear Regression

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

import treeo as to


class Linear(to.Tree):
    w: jnp.ndarray = to.node()
    b: jnp.ndarray = to.node()

    def __init__(self, din, dout, key):
        self.w = jax.random.uniform(key, shape=(din, dout))
        self.b = jnp.zeros(shape=(dout,))

    def __call__(self, x):
        return jnp.dot(x, self.w) + self.b


@jax.value_and_grad
def loss_fn(model, x, y):
    y_pred = model(x)
    loss = jnp.mean((y_pred - y) ** 2)

    return loss


def sgd(param, grad):
    return param - 0.1 * grad


@jax.jit
def train_step(model, x, y):
    loss, grads = loss_fn(model, x, y)
    model = jax.tree_map(sgd, model, grads)

    return loss, model


x = np.random.uniform(size=(500, 1))
y = 1.4 * x - 0.3 + np.random.normal(scale=0.1, size=(500, 1))

key = jax.random.PRNGKey(0)
model = Linear(1, 1, key=key)

for step in range(1000):
    loss, model = train_step(model, x, y)
    if step % 100 == 0:
        print(f"loss: {loss:.4f}")

X_test = np.linspace(x.min(), x.max(), 100)[:, None]
y_pred = model(X_test)

plt.scatter(x, y, c="k", label="data")
plt.plot(X_test, y_pred, c="b", linewidth=2, label="prediction")
plt.legend()
plt.show()
Comments
  • Use field kinds within tree_map

    Firstly, thanks for creating Treeo - it's a fantastic package.

    Is there a way to use methods defined within a field's kind object within a tree_map call? For example, consider the following MWE:

    from dataclasses import dataclass

    import jax
    import jax.numpy as jnp
    import treeo as to
    
    class Parameter:
        def transform(self):
            return jnp.exp(self)
    
    
    @dataclass
    class Model(to.Tree):
        lengthscale: jnp.array = to.field(
            default=jnp.array([1.0]), node=True, kind=Parameter
        )
    

    Is there a way that I could do something similar to the following pseudocode snippet?

    m = Model()
    jax.tree_map(lambda x: x.transform(), to.filter(m, Parameter))
    
    opened by thomaspinder 10
  • Stacking of Treeo.Tree

    I'm running into some issues when trying to stack a list of Treeo.Tree objects into a single object. I've made a short example:

    from dataclasses import dataclass
    
    import jax
    import jax.numpy as jnp
    import treeo as to
    
    @dataclass
    class Person(to.Tree):
        height: jnp.array = to.field(node=True) # I am a node field!
        age_static: jnp.array = to.field(node=False) # I am a static field!, I should not be updated.
        name: str = to.field(node=False) # I am a static field!
    
    persons = [
        Person(height=jnp.array(1.8), age_static=jnp.array(25.), name="John"),
        Person(height=jnp.array(1.7), age_static=jnp.array(100.), name="Wald"),
        Person(height=jnp.array(2.1), age_static=jnp.array(50.), name="Karen")
    ]
    
    # Stack (struct of arrays instead of list of structs)
    jax.tree_map(lambda *values: jnp.stack(values, axis=0), *persons)
    

    However, this fails with the following exception:

    ---------------------------------------------------------------------------
    ValueError                                Traceback (most recent call last)
    Cell In[1], line 18
         11     name: str = to.field(node=False) # I am a static field!
         13 persons = [
         14     Person(height=jnp.array(1.8), age_static=jnp.array(25.), name="John"),
         15     Person(height=jnp.array(1.7), age_static=jnp.array(100.), name="Wald"),
         16     Person(height=jnp.array(2.1), age_static=jnp.array(50.), name="Karen")
         17 ]
    ---> 18 jax.tree_map(lambda *values: jnp.stack(values, axis=0), *persons)
    
    File ~/workspace/lcms_polymer_model/env/env_conda_local/lcms_polymer_model_env/lib/python3.10/site-packages/jax/_src/tree_util.py:199, in tree_map(f, tree, is_leaf, *rest)
        166 """Maps a multi-input function over pytree args to produce a new pytree.
        167 
        168 Args:
       (...)
        196   [[5, 7, 9], [6, 1, 2]]
        197 """
        198 leaves, treedef = tree_flatten(tree, is_leaf)
    --> 199 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
        200 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
    
    File ~/workspace/lcms_polymer_model/env/env_conda_local/lcms_polymer_model_env/lib/python3.10/site-packages/jax/_src/tree_util.py:199, in <listcomp>(.0)
        166 """Maps a multi-input function over pytree args to produce a new pytree.
        167 
        168 Args:
       (...)
        196   [[5, 7, 9], [6, 1, 2]]
        197 """
        198 leaves, treedef = tree_flatten(tree, is_leaf)
    --> 199 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
        200 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
    
    ValueError: Mismatch custom node data: {'_field_metadata': {'height': <treeo.types.FieldMetadata object at 0x7fb8b898ba00>, 'age_static': <treeo.types.FieldMetadata object at 0x7fb8b90c0a90>, 'name': <treeo.types.FieldMetadata object at 0x7fb8b8bf9db0>, '_field_metadata': <treeo.types.FieldMetadata object at 0x7fb8b89b56f0>, '_factory_fields': <treeo.types.FieldMetadata object at 0x7fb8b89b5750>, '_default_field_values': <treeo.types.FieldMetadata object at 0x7fb8b89b5660>, '_subtrees': <treeo.types.FieldMetadata object at 0x7fb8b89b5720>}, 'age_static': DeviceArray(25., dtype=float32, weak_type=True), 'name': 'John'} != {'_field_metadata': {'height': <treeo.types.FieldMetadata object at 0x7fb8b898ba00>, 'age_static': <treeo.types.FieldMetadata object at 0x7fb8b90c0a90>, 'name': <treeo.types.FieldMetadata object at 0x7fb8b8bf9db0>, '_field_metadata': <treeo.types.FieldMetadata object at 0x7fb8b89b56f0>, '_factory_fields': <treeo.types.FieldMetadata object at 0x7fb8b89b5750>, '_default_field_values': <treeo.types.FieldMetadata object at 0x7fb8b89b5660>, '_subtrees': <treeo.types.FieldMetadata object at 0x7fb8b89b5720>}, 'age_static': DeviceArray(100., dtype=float32, weak_type=True), 'name': 'Wald'}; value: Person(height=DeviceArray(1.7, dtype=float32, weak_type=True), age_static=DeviceArray(100., dtype=float32, weak_type=True), name='Wald').
    

    Versions used:

    • JAX: 0.3.20
    • Treeo: 0.0.10

    From a certain perspective this is expected, because jax.tree_map does not apply to static (node=False) fields. In that sense, this might not really be an issue with Treeo. However, I'm looking for some guidance on how to still be able to stack objects like this with static fields. Has anyone tried something similar and come up with a nice solution?
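
One possible direction, sketched here in plain JAX rather than with Treeo, is to keep only the arrays to be stacked inside the pytree and to hold per-instance metadata that differs between objects outside of it. The dict layout and the variable names below are hypothetical stand-ins for the Person objects above, and the sketch assumes it is acceptable to treat every array that varies across instances as a node.

```python
import jax
import jax.numpy as jnp

# Hypothetical plain-dict stand-ins: only stackable arrays live in the pytree.
people_nodes = [
    {"height": jnp.array(1.8), "age": jnp.array(25.0)},
    {"height": jnp.array(1.7), "age": jnp.array(100.0)},
    {"height": jnp.array(2.1), "age": jnp.array(50.0)},
]

# Metadata that differs per instance is kept outside the pytree,
# so every tree passed to tree_map has the same structure.
names = ["John", "Wald", "Karen"]

# Struct of arrays instead of list of structs.
stacked = jax.tree_util.tree_map(
    lambda *values: jnp.stack(values, axis=0), *people_nodes
)
```

The stacked result has one leading axis per field, and names can be zipped back with it when individual records are needed.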

    opened by peterroelants 3
  • Jitting twice for a class method

    import jax
    import jax.numpy as jnp
    import treeo as to
    
    class A(to.Tree):
        X: jnp.array = to.field(node=True)
        
        def __init__(self):
            self.X = jnp.ones((50, 50))
    
        @jax.jit
        def f(self, Y):
            return jnp.sum(Y ** 2) * jnp.sum(self.X ** 2)
    
    Y = jnp.ones(2)
    for i in range(5):
        print(A.f._cache_size())
        a = A()
        a.f(Y)
    

    The output of the above is 0 1 2 2 2 with jax 0.3.15. No idea what's happening. It seems to work fine with 0.3.10 and the output is 0 1 1 1 1. Thanks.

    opened by pipme 2
  • Change Mutable API

    Changes

    • Previously: self.mutable(*args, method=method, **kwargs)
    • Now: self.mutable(method=method)(*args, **kwargs)
    • Opaque API is removed
    • inplace argument is now only available for apply.
    • Immutable.{mutable, toplevel_mutable} methods are removed.
    fix 
    opened by cgarciae 1
  • Improve mutability support

    Changes

    • Fixes issues with immutability in compact context
    • The make_mutable context manager and the mutable function now expose a toplevel_only: bool argument.
    • Adds a _get_unbound_method private function in utils.
    feature 
    opened by cgarciae 1
  • Bug Fixes from 0.0.11

    Changes

    • Fixes an issue that disabled mutability inside __init__ for Immutable classes when TreeMeta's constructor method is overloaded.
    • Fixes the Apply.apply mixin method.

    Closes cgarciae/treex#68

    fix 
    opened by cgarciae 1
  • Adds support for immutable Trees

    Changes

    • Adds an Immutable mixin that can make Trees effectively immutable (as far as python permits).
    • Immutable contains the .replace and .mutable methods that let you manipulate state in a functionally pure fashion.
    • Adds the mutable function transformation / decorator, which lets you turn functions that perform mutable operations into pure functions.
    opened by cgarciae 1
  • Add the option of using add_field_info inside map

    This PR addresses the comments made in #2. An additional argument is added to map to allow a field_info boolean flag to be passed. When true, jax.tree_map is carried out under the with add_field_info(): context manager.

    Tests have been added to check for correct function application on classes that contain Trees with mixed kind types.

    A brief section has been added to the documentation to reflect the above changes.

    opened by thomaspinder 1
  • Get all unique kinds

    Hi,

    Is there a way that I can get a list of all the unique kinds within a nested dataclass? For example:

    class KindOne: pass
    class KindTwo: pass
    
    @dataclass
    class SubModel(to.Tree):
        parameter: jnp.array = to.field(
            default=jnp.array([1.0]), node=True, kind=KindOne
        )
    
    
    @dataclass 
    class Model(to.Tree):
        parameter: jnp.array = to.field(
            default=jnp.array([1.0]), node=True, kind=KindTwo
        )
    
    m = Model()
    
    m.unique_kinds() # [KindOne, KindTwo]
    
    opened by thomaspinder 1
  • Compact

    Changes

    • Removes opaque_is_equal; the same functionality is available through opaque.
    • Adds the compact decorator, which enables the definition of Tree subnodes at runtime.
    • Adds the Compact mixin, which adds the first_run property and the get_field method.
    opened by cgarciae 0
  • Relax jax/jaxlib version constraints

    Now that jax 0.3.0 and jaxlib 0.3.0 have been released the version constraints in pyproject.toml are outdated.

    https://github.com/cgarciae/treeo/blob/a402f3f69557840cfbee4d7804964b8e2c47e3f7/pyproject.toml#L16-L17

    This corresponds to the version constraint jax<0.3.0,>=0.2.18 (https://python-poetry.org/docs/dependency-specification/#caret-requirements). Now that jax v0.3.0 has been released (https://github.com/google/jax/releases/tag/jax-v0.3.0) this doesn't work with the latest version. I think the same applies to jaxlib as well, since it also got upgraded to v0.3.0 (https://github.com/google/jax/releases/tag/jaxlib-v0.3.0).

    opened by samuela 4
  • TracedArrays treated as nodes by default

    Currently, for convenience, all non-Tree fields that are not declared are set to static fields, as most fields actually are. However, for more complex applications a Traced Array might be passed where a static field is usually expected.

    A simple solution is to change the current node policy to treat any field containing a TracedArray as a node; this would be the same as the current policy for Tree fields.

    opened by cgarciae 0
Releases (0.2.1)
Owner
Cristian Garcia
ML Engineer at Quansight, working on Treex and Elegy.