JAX + dataclasses

Overview

jax_dataclasses

jax_dataclasses provides a wrapper around dataclasses.dataclass for use in JAX, which enables automatic support for:

  • Pytree registration. This allows dataclasses to be used at API boundaries in JAX, which is necessary for function transformations, JIT compilation, etc.
  • Serialization via flax.serialization.

Notably, jax_dataclasses is designed to work seamlessly with static analysis, including tools like mypy and jedi.

Heavily influenced by some great existing work; see Alternatives for comparisons.

Installation

pip install jax_dataclasses

Core interface

jax_dataclasses is meant to provide a drop-in replacement for dataclasses.dataclass:

  • jax_dataclasses.pytree_dataclass has the same interface as dataclasses.dataclass, but also registers the target class as a pytree container.
  • jax_dataclasses.static_field has the same interface as dataclasses.field, but will also mark the field as static. In a pytree node, static fields will be treated as part of the treedef instead of as a child of the node; all fields that are not explicitly marked static should contain arrays or child nodes.
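
To make the distinction between children and static fields concrete, here is a minimal sketch that registers a plain frozen dataclass as a pytree by hand using jax.tree_util.register_pytree_node. The class Params, its fields, and the helper names are hypothetical, and the sketch only illustrates the concept; it is not meant to show how jax_dataclasses is implemented.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Params:
    weights: jnp.ndarray  # child of the node: traced by JAX
    scale: int            # static: stored in the treedef, never traced


def _flatten(p):
    # First tuple: children (leaves or sub-trees). Second: hashable aux data
    # that becomes part of the treedef.
    return (p.weights,), (p.scale,)


def _unflatten(aux, children):
    return Params(weights=children[0], scale=aux[0])


jax.tree_util.register_pytree_node(Params, _flatten, _unflatten)


@jax.jit
def apply(p: Params) -> jnp.ndarray:
    # Inside the jitted function, `p.scale` is still a plain Python int.
    return p.weights * p.scale


params = Params(weights=jnp.ones(3), scale=2)
leaves = jax.tree_util.tree_leaves(params)  # only `weights` is a leaf
result = apply(params)
```

Because scale lives in the treedef, calling apply with a different scale triggers a retrace, while new weights of the same shape and dtype reuse the compiled function.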

We also provide several aliases: jax_dataclasses.[field, asdict, astuple, is_dataclass, replace] are all identical to their counterparts in the standard dataclasses library.

Mutations

All dataclasses are automatically marked as frozen and thus immutable (even when no frozen= parameter is passed in). To make changes to nested structures easier, we provide an interface that will (a) make a copy of a pytree and (b) return a context in which any of that copy's contained dataclasses are temporarily mutable:

from jax import numpy as jnp
import jax_dataclasses

@jax_dataclasses.pytree_dataclass
class Node:
  child: jnp.ndarray

obj = Node(child=jnp.zeros(3))

with jax_dataclasses.copy_and_mutate(obj) as obj_updated:
  # Make mutations to the dataclass. This is primarily useful for nested
  # dataclasses.
  #
  # Also does input validation: if the treedef, leaf shapes, or dtypes of `obj`
  # and `obj_updated` don't match, an AssertionError will be raised.
  # This can be disabled with a `validate=False` argument.
  obj_updated.child = jnp.ones(3)

print(obj)
print(obj_updated)
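
For comparison, here is a minimal sketch of the same kind of update on nested frozen dataclasses using only the standard library's dataclasses.replace. The classes Inner, Middle, and Outer are hypothetical and exist only to show how the nesting compounds.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Inner:
    value: float


@dataclasses.dataclass(frozen=True)
class Middle:
    inner: Inner


@dataclasses.dataclass(frozen=True)
class Outer:
    middle: Middle


outer = Outer(middle=Middle(inner=Inner(value=0.0)))

# Changing one deeply nested field needs one `replace` call per level.
updated = dataclasses.replace(
    outer,
    middle=dataclasses.replace(
        outer.middle,
        inner=dataclasses.replace(outer.middle.inner, value=1.0),
    ),
)

print(outer.middle.inner.value)
print(updated.middle.inner.value)
```

Inside a copy_and_mutate context, the equivalent change would be a single attribute assignment on the copied object, along the lines of obj_updated.middle.inner.value = 1.0.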

Alternatives

A few other solutions exist for automatically integrating dataclass-style objects into pytree structures. Great ones include: chex.dataclass, flax.struct, and tjax.dataclass. These all influenced this library.

The main differentiators of jax_dataclasses are:

  • Static analysis support. Libraries like dataclasses and attrs rely on tooling-specific custom plugins for static analysis, and no such plugins exist for chex or flax. tjax has a custom mypy plugin to enable type checking, but it isn't supported by other tools. Because @jax_dataclasses.pytree_dataclass has the same API as @dataclasses.dataclass, it can add pytree registration behavior at runtime while being treated as the standard decorator during static analysis. This means that all static checkers, language servers, and autocomplete engines that support the standard dataclasses library should work out of the box with jax_dataclasses.

  • Nested dataclasses. Making replacements/modifications in deeply nested dataclasses is generally very frustrating. The three alternatives all introduce a .replace(self, ...) method to dataclasses that's a bit more convenient than the traditional dataclasses.replace(obj, ...) API for shallow changes, but still becomes really cumbersome to use when dataclasses are nested. jax_dataclasses.copy_and_mutate() is introduced to address this.

  • Static field support. Parameters that should not be traced in JAX should be marked as static. This is supported in flax, tjax, and jax_dataclasses, but not chex.

  • Serialization. When working with flax, being able to serialize dataclasses is really handy. This is supported in flax.struct (naturally) and jax_dataclasses, but not chex or tjax.
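
One common way to make a decorator look like the standard one to static tools while behaving differently at runtime is to branch on typing.TYPE_CHECKING. The sketch below shows that general technique with the standard library only. The names registered_dataclass and REGISTERED are hypothetical, the registry merely stands in for pytree registration, and this is not necessarily how jax_dataclasses is implemented.

```python
import dataclasses
from typing import TYPE_CHECKING

# Hypothetical stand-in for "extra runtime behavior" such as pytree registration.
REGISTERED: list = []

if TYPE_CHECKING:
    # Static analyzers take this branch and see the standard decorator.
    from dataclasses import dataclass as registered_dataclass
else:
    # At runtime, this wrapper is used instead.
    def registered_dataclass(cls=None, **kwargs):
        kwargs["frozen"] = True  # always produce immutable dataclasses

        def wrap(c):
            c = dataclasses.dataclass(**kwargs)(c)
            REGISTERED.append(c)
            return c

        # Support both `@registered_dataclass` and `@registered_dataclass(...)`.
        return wrap if cls is None else wrap(cls)


@registered_dataclass
class Point:
    x: float
    y: float


point = Point(x=1.0, y=2.0)
```

A type checker or language server resolves Point exactly as it would a class decorated with @dataclasses.dataclass, so field names and constructor signatures are inferred without a plugin.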

Misc

This code was originally written for and factored out of jaxfg, where Nick Heppert provided valuable feedback!

Comments
  • Fix infinite loop for cycles in pytrees

    I have a rather big dataclass to describe a robot model, that includes a graph of links and a list of joints. Each node of the graph references the parent link and all the child links. Each joint object references its parent and child links.

    When I try to copy_and_mutate any of these objects, maybe due to all this nesting, an infinite loop occurs. I suspect that the existing logic tries to unfreeze all the leaves of the pytree, but the high interconnection and the properties of mutable Python types lead to a never-ending unfreezing process.

    This PR addresses this edge case by storing the IDs of objects that have already been unfrozen. It solves my problem, and it should not add any noticeable performance degradation.

    cc @brentyi

    opened by diegoferigo 10
  • Delayed initialisation of static fields

    First of all, thank you for the amazing library! I have recently discovered jax_dataclasses and I have decided to port my messy JAX functional code to a more organised object-oriented code based on jax_dataclasses.

    In my application, I have some quantities derived from the attributes of the dataclass. They are static values used to determine the shape of tensors during JIT compilation. I would like to include them as attributes of the dataclass, but I'm getting an error and I would like to know if there is a workaround.

    Here is a simple example, where the attribute _sum is a derived static field that depends on the constant value of the array a.

    import jax
    import jax.numpy as jnp
    import jax_dataclasses as jdc
    
    @jdc.pytree_dataclass()
    class PyTreeDataclass:
        a: jnp.ndarray
        _sum: int = jdc.static_field(init=False, repr=False)
    
        def __post_init__(self):
            object.__setattr__(self, "_sum", self.a.sum().item())
    
    def print_pytree(obj):
        print(obj._sum)
    
    obj = PyTreeDataclass(jnp.arange(4))
    print_pytree(obj)
    jax.jit(print_pytree)(obj)
    

    The non-jitted version works, but when print_pytree is jitted I get the following error.

    File "jax_dataclasses_issue.py", line 14, in __post_init__
        object.__setattr__(self, "_sum", self.a.sum().item())
    AttributeError: 'bool' object has no attribute 'sum'
    

    Is there a way to compute in the __post_init__ the value of static fields not initialized in __init__ that depend on jnp.ndarray attributes of the dataclass?

    opened by lucagrementieri 4
  • `jax.tree_leaves` is deprecated

    The file jax_dataclasses/_copy_and_mutate.py raises many warnings about a deprecated function.

    FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
    
    opened by lucagrementieri 1
  • Use jaxtyping to enrich type annotations

    I just discovered the jaxtyping library and I think it could be an interesting alternative to the current typing system proposed by jax_dataclasses.

    jaxtyping supports variable-size axes and symbolic expressions in terms of other variable-size axes (see https://github.com/google/jaxtyping/blob/main/API.md), and it has very few requirements.

    Do you think that it could be added to jax_dataclasses?

    opened by lucagrementieri 4
  • Serialization of static fields?

    Thanks for the handy library!

    I have a pytree_dataclass that contains a few static_fields that I would like to have serialized by the facilities in flax.serialization. I noticed that jax_dataclasses.asdict handles these, but that flax.serialization.to_state_dict and flax.serialization.to_bytes both ignore them. What is the correct way (if any) to have these fields included in flax's serialization? Should I be using another technique?

    import jax_dataclasses as jdc
    from jax import numpy as jnp
    import flax.serialization as fs
    
    
    @jdc.pytree_dataclass
    class Demo:
        a: jnp.ndarray = jnp.ones(3)
        b: bool = jdc.static_field(default=False)
    
    
    demo = Demo()
    print(f'{jdc.asdict(demo) = }')
    print(f'{fs.to_state_dict(demo) = }')
    print(f'{fs.from_bytes(Demo, fs.to_bytes(demo)) = }')
    
    # jdc.asdict(demo) = {'a': array([1., 1., 1.]), 'b': False}
    # fs.to_state_dict(demo) = {'a': DeviceArray([1., 1., 1.], dtype=float64)}
    # fs.from_bytes(Demo, fs.to_bytes(demo)) = {'a': array([1., 1., 1.])}
    

    Thanks in advance!

    opened by erdmann 3
Releases(v1.5.1)
Owner
Brent Yi