Treeo

A small library for creating and manipulating custom JAX Pytree classes

Overview

  • Light-weight: has no dependencies other than jax.
  • Compatible: Treeo Tree objects are compatible with any jax function that accepts Pytrees.
  • Standards-based: treeo.field is built on top of python's dataclasses.field.
  • Flexible: Treeo is compatible with both dataclass and non-dataclass classes.

Treeo lets you easily create class-based Pytrees so your custom objects interact seamlessly with JAX. Uses of Treeo range from creating simple JAX-aware utility classes to serving as the core abstraction for full-blown frameworks. Treeo was originally extracted from the core of Treex and shares a lot in common with flax.struct.

Documentation | User Guide

Installation

Install using pip:

pip install treeo

Basics

With Treeo you can easily define your own custom Pytree classes by inheriting from Treeo's Tree class and using the field function to declare which fields are nodes (children) and which are static (metadata):

from dataclasses import dataclass

import jax
import jax.numpy as jnp
import treeo as to

@dataclass
class Person(to.Tree):
    height: jnp.ndarray = to.field(node=True)  # I am a node field!
    name: str = to.field(node=False)           # I am a static field!

field is just a wrapper around dataclasses.field, so you can define your Pytrees as dataclasses, but Treeo fully supports non-dataclass classes as well (a sketch of the non-dataclass style follows the next example). Since all Tree instances are Pytrees, they work with the various functions from the jax library as expected:

p = Person(height=jnp.array(1.8), name="John")

# Trees can be jitted!
jax.jit(lambda person: person)(p) # Person(height=array(1.8), name='John')

# Trees can be mapped!
jax.tree_map(lambda x: 2 * x, p) # Person(height=array(3.6), name='John')

Kinds

Treeo also includes a kind system that lets you give semantic meaning to fields (what a field represents within your application). A kind is just a type you pass to field via its kind argument:

class Parameter: pass
class BatchStat: pass

class BatchNorm(to.Tree):
    scale: jnp.ndarray = to.field(node=True, kind=Parameter)
    mean: jnp.ndarray = to.field(node=True, kind=BatchStat)

Kinds are very useful as a filtering mechanism via treeo.filter:

model = BatchNorm(...)

# select only Parameters, mean is filtered out
params = to.filter(model, Parameter) # BatchNorm(scale=array(...), mean=Nothing)

Nothing behaves much like None in Python; it is a special value used within Treeo to represent the absence of a value.
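
Assuming Nothing is registered as a leafless Pytree (which is how the filtered output above behaves), a filtered Tree still flows through jax functions and only the remaining leaves are touched. Continuing the example above, a minimal sketch:

params = to.filter(model, Parameter)     # BatchNorm(scale=array(...), mean=Nothing)
jax.tree_map(lambda x: 2 * x, params)    # only scale is doubled; mean stays Nothing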

Treeo also offers the merge function, which lets you rejoin filtered Trees with logic similar to Python's dict.update but applied recursively:

def loss_fn(params, model, ...):
    # add traced params to model
    model = to.merge(model, params)
    ...

# gradient only w.r.t. params
params = to.filter(model, Parameter) # BatchNorm(scale=array(...), mean=Nothing)
grads = jax.grad(loss_fn)(params, model, ...)
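
Putting filter and merge together, here is a minimal end-to-end sketch; the constructor, toy data, and loss below are illustrative assumptions, not part of the BatchNorm example above:

import jax
import jax.numpy as jnp
import treeo as to

class Parameter: pass
class BatchStat: pass

class BatchNorm(to.Tree):
    scale: jnp.ndarray = to.field(node=True, kind=Parameter)
    mean: jnp.ndarray = to.field(node=True, kind=BatchStat)

    def __init__(self, scale, mean):
        self.scale = scale
        self.mean = mean

def loss_fn(params, model, x):
    model = to.merge(model, params)      # put the traced params back into the model
    y = (x - model.mean) * model.scale   # toy computation
    return jnp.mean(y ** 2)

model = BatchNorm(scale=jnp.ones(3), mean=jnp.zeros(3))
x = jnp.ones((8, 3))

params = to.filter(model, Parameter)           # scale kept, mean becomes Nothing
grads = jax.grad(loss_fn)(params, model, x)    # gradients only w.r.t. scale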

For a more in-depth tour check out the User Guide.

Examples

A simple Tree

from dataclasses import dataclass

import jax
import jax.numpy as jnp
import treeo as to

@dataclass
class Character(to.Tree):
    position: jnp.ndarray = to.field(node=True)    # node field
    name: str = to.field(node=False, opaque=True)  # static field

character = Character(position=jnp.array([0, 0]), name='Adam')

# character can freely pass through jit
@jax.jit
def update(character: Character, velocity, dt) -> Character:
    character.position += velocity * dt
    return character

character = update(character, velocity=jnp.array([1.0, 0.2]), dt=0.1)

A Stateful Tree

from dataclasses import dataclass

import jax
import jax.numpy as jnp
import treeo as to

@dataclass
class Counter(to.Tree):
    n: jnp.ndarray = to.field(default=jnp.array(0), node=True) # node
    step: int = to.field(default=1, node=False) # static

    def inc(self):
        self.n += self.step

counter = Counter(step=2) # Counter(n=jnp.array(0), step=2)

@jax.jit
def update(counter: Counter):
    counter.inc()
    return counter

counter = update(counter) # Counter(n=jnp.array(2), step=2)

# map over the tree: only node fields are mapped, static fields are untouched
counter = jax.tree_map(lambda x: 2 * x, counter)  # Counter(n=jnp.array(4), step=2)

Full Example - Linear Regression

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

import treeo as to


class Linear(to.Tree):
    w: jnp.ndarray = to.node()
    b: jnp.ndarray = to.node()

    def __init__(self, din, dout, key):
        self.w = jax.random.uniform(key, shape=(din, dout))
        self.b = jnp.zeros(shape=(dout,))

    def __call__(self, x):
        return jnp.dot(x, self.w) + self.b


@jax.value_and_grad
def loss_fn(model, x, y):
    y_pred = model(x)
    loss = jnp.mean((y_pred - y) ** 2)

    return loss


def sgd(param, grad):
    return param - 0.1 * grad


@jax.jit
def train_step(model, x, y):
    loss, grads = loss_fn(model, x, y)
    model = jax.tree_map(sgd, model, grads)

    return loss, model


x = np.random.uniform(size=(500, 1))
y = 1.4 * x - 0.3 + np.random.normal(scale=0.1, size=(500, 1))

key = jax.random.PRNGKey(0)
model = Linear(1, 1, key=key)

for step in range(1000):
    loss, model = train_step(model, x, y)
    if step % 100 == 0:
        print(f"loss: {loss:.4f}")

X_test = np.linspace(x.min(), x.max(), 100)[:, None]
y_pred = model(X_test)

plt.scatter(x, y, c="k", label="data")
plt.plot(X_test, y_pred, c="b", linewidth=2, label="prediction")
plt.legend()
plt.show()
Comments
  • Use field kinds within tree_map

    Firstly, thanks for creating Treeo - it's a fantastic package.

    Is there a way to use methods defined within a field's kind object within a tree_map call? For example, consider the following MWE

    from dataclasses import dataclass

    import jax
    import jax.numpy as jnp
    import treeo as to
    
    class Parameter:
        def transform(self):
            return jnp.exp(self)
    
    
    @dataclass
    class Model(to.Tree):
        lengthscale: jnp.array = to.field(
            default=jnp.array([1.0]), node=True, kind=Parameter
        )
    

    is there a way that I could do something similar to the following pseudocode snippet:

    m = Model()
    jax.tree_map(lambda x: x.transform(), to.filter(m, Parameter))
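
    One possible approach (a sketch, not an official Treeo answer), reusing the Model and Parameter definitions above: the leaves themselves are plain arrays and the kind is only metadata, so the kind's method can be applied as an unbound function to the filtered leaves and the result merged back:

    m = Model()
    transformed = jax.tree_map(Parameter.transform, to.filter(m, Parameter))
    m = to.merge(m, transformed)  # lengthscale is now jnp.exp of its previous value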
    
    opened by thomaspinder 10
  • Stacking of Treeo.Tree

    I'm running into some issues when trying to stack a list of Treeo.Tree objects into a single object. I've made a short example:

    from dataclasses import dataclass
    
    import jax
    import jax.numpy as jnp
    import treeo as to
    
    @dataclass
    class Person(to.Tree):
        height: jnp.array = to.field(node=True) # I am a node field!
        age_static: jnp.array = to.field(node=False) # I am a static field!, I should not be updated.
        name: str = to.field(node=False) # I am a static field!
    
    persons = [
        Person(height=jnp.array(1.8), age_static=jnp.array(25.), name="John"),
        Person(height=jnp.array(1.7), age_static=jnp.array(100.), name="Wald"),
        Person(height=jnp.array(2.1), age_static=jnp.array(50.), name="Karen")
    ]
    
    # Stack (struct of arrays instead of list of structs)
    jax.tree_map(lambda *values: jnp.stack(values, axis=0), *persons)
    

    However, this fails with the following exception:

    ---------------------------------------------------------------------------
    ValueError                                Traceback (most recent call last)
    Cell In[1], line 18
         11     name: str = to.field(node=False) # I am a static field!
         13 persons = [
         14     Person(height=jnp.array(1.8), age_static=jnp.array(25.), name="John"),
         15     Person(height=jnp.array(1.7), age_static=jnp.array(100.), name="Wald"),
         16     Person(height=jnp.array(2.1), age_static=jnp.array(50.), name="Karen")
         17 ]
    ---> 18 jax.tree_map(lambda *values: jnp.stack(values, axis=0), *persons)
    
    File ~/workspace/lcms_polymer_model/env/env_conda_local/lcms_polymer_model_env/lib/python3.10/site-packages/jax/_src/tree_util.py:199, in tree_map(f, tree, is_leaf, *rest)
        166 """Maps a multi-input function over pytree args to produce a new pytree.
        167 
        168 Args:
       (...)
        196   [[5, 7, 9], [6, 1, 2]]
        197 """
        198 leaves, treedef = tree_flatten(tree, is_leaf)
    --> 199 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
        200 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
    
    File ~/workspace/lcms_polymer_model/env/env_conda_local/lcms_polymer_model_env/lib/python3.10/site-packages/jax/_src/tree_util.py:199, in <listcomp>(.0)
        166 """Maps a multi-input function over pytree args to produce a new pytree.
        167 
        168 Args:
       (...)
        196   [[5, 7, 9], [6, 1, 2]]
        197 """
        198 leaves, treedef = tree_flatten(tree, is_leaf)
    --> 199 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
        200 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
    
    ValueError: Mismatch custom node data: {'_field_metadata': {'height': <treeo.types.FieldMetadata object at 0x7fb8b898ba00>, 'age_static': <treeo.types.FieldMetadata object at 0x7fb8b90c0a90>, 'name': <treeo.types.FieldMetadata object at 0x7fb8b8bf9db0>, '_field_metadata': <treeo.types.FieldMetadata object at 0x7fb8b89b56f0>, '_factory_fields': <treeo.types.FieldMetadata object at 0x7fb8b89b5750>, '_default_field_values': <treeo.types.FieldMetadata object at 0x7fb8b89b5660>, '_subtrees': <treeo.types.FieldMetadata object at 0x7fb8b89b5720>}, 'age_static': DeviceArray(25., dtype=float32, weak_type=True), 'name': 'John'} != {'_field_metadata': {'height': <treeo.types.FieldMetadata object at 0x7fb8b898ba00>, 'age_static': <treeo.types.FieldMetadata object at 0x7fb8b90c0a90>, 'name': <treeo.types.FieldMetadata object at 0x7fb8b8bf9db0>, '_field_metadata': <treeo.types.FieldMetadata object at 0x7fb8b89b56f0>, '_factory_fields': <treeo.types.FieldMetadata object at 0x7fb8b89b5750>, '_default_field_values': <treeo.types.FieldMetadata object at 0x7fb8b89b5660>, '_subtrees': <treeo.types.FieldMetadata object at 0x7fb8b89b5720>}, 'age_static': DeviceArray(100., dtype=float32, weak_type=True), 'name': 'Wald'}; value: Person(height=DeviceArray(1.7, dtype=float32, weak_type=True), age_static=DeviceArray(100., dtype=float32, weak_type=True), name='Wald').
    

    Versions used:

    • JAX: 0.3.20
    • Treeo: 0.0.10

    From a certain perspective this is expected, because jax.tree_map does not apply to static (node=False) fields, so in this sense it might not really be an issue with Treeo. However, I'm looking for some guidance on how to still be able to stack objects like this that have static fields. Has anyone tried something similar and come up with a nice solution?
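
    A hedged sketch of one workaround (not an official Treeo recipe), reusing the persons list above: flatten each Person, stack the corresponding node leaves, and unflatten with the treedef of the first instance; note that the static fields (name, age_static) of that first instance are the ones kept in the stacked result:

    import jax
    import jax.numpy as jnp

    leaves_per_person = [jax.tree_util.tree_leaves(p) for p in persons]
    treedef = jax.tree_util.tree_structure(persons[0])
    stacked = jax.tree_util.tree_unflatten(
        treedef,
        [jnp.stack(leaves, axis=0) for leaves in zip(*leaves_per_person)],
    )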

    opened by peterroelants 3
  • Jitting twice for a class method

    import jax
    import jax.numpy as jnp
    import treeo as to
    
    class A(to.Tree):
        X: jnp.array = to.field(node=True)
        
        def __init__(self):
            self.X = jnp.ones((50, 50))
    
        @jax.jit
        def f(self, Y):
            return jnp.sum(Y ** 2) * jnp.sum(self.X ** 2)
    
    Y = jnp.ones(2)
    for i in range(5):
        print(A.f._cache_size())
        a = A()
        a.f(Y)
    

    The output of the above is 0 1 2 2 2 with jax 0.3.15. No idea what's happening. It seems to work fine with 0.3.10 and the output is 0 1 1 1 1. Thanks.

    opened by pipme 2
  • Change Mutable API

    Changes

    • Previously: self.mutable(*args, method=method, **kwargs)
    • Now: self.mutable(method=method)(*args, **kwargs)
    • The Opaque API is removed.
    • The inplace argument is now only available for apply.
    • The Immutable.{mutable, toplevel_mutable} methods are removed.
    fix 
    opened by cgarciae 1
  • Improve mutability support

    Changes

    • Fixes issues with immutability in compact context
    • The make_mutable context manager and the mutable function now expose a toplevel_only: bool argument.
    • Adds a _get_unbound_method private function in utils.
    feature 
    opened by cgarciae 1
  • Bug Fixes from 0.0.11

    Changes

    • Fixes an issue that disabled mutability inside __init__ for Immutable classes when TreeMeta's constructor method is overloaded.
    • Fixes the Apply.apply mixin method.

    Closes cgarciae/treex#68

    fix 
    opened by cgarciae 1
  • Adds support for immutable Trees

    Changes

    • Adds an Immutable mixin that can make Trees effectively immutable (as far as python permits).
    • Immutable contains the .replace and .mutable methods that let you manipulate state in a functionally pure fashion.
    • Adds the mutable function transformation / decorator, which lets you turn functions that perform mutable operations into pure functions.
    opened by cgarciae 1
  • Add the option of using add_field_info inside map

    This PR addresses the comments made in #2. An additional argument is created within map to allow a field_info boolean flag to be passed. When true, jax.tree_map is carried out under the with add_field_info(): context manager.

    Tests have been added to check for correct function application on classes containing Trees with mixed kind types.

    A brief section has been added to the documentation to reflect the above changes.

    opened by thomaspinder 1
  • Get all unique kinds

    Hi,

    Is there a way that I can get a list of all the unique kinds within a nested dataclass? For example:

    from dataclasses import dataclass

    import jax.numpy as jnp
    import treeo as to

    class KindOne: pass
    class KindTwo: pass
    
    @dataclass
    class SubModel(to.Tree):
        parameter: jnp.array = to.field(
            default=jnp.array([1.0]), node=True, kind=KindOne
        )
    
    
    @dataclass 
    class Model(to.Tree):
        parameter: jnp.array = to.field(
            default=jnp.array([1.0]), node=True, kind=KindTwo
        )
    
    m = Model()
    
    m.unique_kinds() # [KindOne, KindTwo]
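
    A sketch of one way this might be done today, assuming the add_field_info context manager (mentioned in the map PR above) yields field-info leaves carrying a kind attribute; that attribute name is an assumption, not confirmed API:

    import jax
    import treeo as to

    def unique_kinds(tree):
        with to.add_field_info():
            infos = jax.tree_util.tree_leaves(tree)
        return {info.kind for info in infos}

    unique_kinds(Model())  # e.g. {KindOne, KindTwo} for a Model that nests SubModel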
    
    opened by thomaspinder 1
  • Compact

    Changes

    • Removes opaque_is_equal; the same functionality is available through opaque.
    • Adds the compact decorator, which enables the definition of Tree subnodes at runtime.
    • Adds the Compact mixin, which adds the first_run property and the get_field method.
    opened by cgarciae 0
  • Relax jax/jaxlib version constraints

    Now that jax 0.3.0 and jaxlib 0.3.0 have been released, the version constraints in pyproject.toml are outdated.

    https://github.com/cgarciae/treeo/blob/a402f3f69557840cfbee4d7804964b8e2c47e3f7/pyproject.toml#L16-L17

    This corresponds to the version constraint jax<0.3.0,>=0.2.18 (https://python-poetry.org/docs/dependency-specification/#caret-requirements). Now that jax v0.3.0 has been released (https://github.com/google/jax/releases/tag/jax-v0.3.0) this doesn't work with the latest version. I think the same applies to jaxlib as well, since it also got upgraded to v0.3.0 (https://github.com/google/jax/releases/tag/jaxlib-v0.3.0).

    opened by samuela 4
  • TracedArrays treated as nodes by default

    Currently, for convenience, all non-Tree fields that are not declared are treated as static fields, since most fields actually are static. However, for more complex applications a traced array might be passed where a static field is usually expected.

    A simple solution is to change the current policy so that any field containing a traced array is treated as a node; this would match the current policy for Tree fields.

    opened by cgarciae 0
Releases(0.2.1)

Owner

Cristian Garcia: ML Engineer at Quansight, working on Treex and Elegy.