JAX + dataclasses

Overview

jax_dataclasses

jax_dataclasses provides a wrapper around dataclasses.dataclass for use in JAX, which enables automatic support for:

  • Pytree registration. This allows dataclasses to be used at API boundaries in JAX, which is necessary for function transformations, JIT compilation, etc.
  • Serialization via flax.serialization.

Notably, jax_dataclasses is designed to work seamlessly with static analysis, including tools like mypy and jedi.

This library is heavily influenced by some great existing work; see Alternatives for comparisons.

Installation

pip install jax_dataclasses

Core interface

jax_dataclasses is meant to provide a drop-in replacement for dataclasses.dataclass:

  • jax_dataclasses.pytree_dataclass has the same interface as dataclasses.dataclass, but also registers the target class as a pytree container.
  • jax_dataclasses.static_field has the same interface as dataclasses.field, but will also mark the field as static. In a pytree node, static fields will be treated as part of the treedef instead of as a child of the node; all fields that are not explicitly marked static should contain arrays or child nodes.

We also provide several aliases: jax_dataclasses.[field, asdict, astuple, is_dataclass, replace] are all identical to their counterparts in the standard dataclasses library.
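
To make the distinction between child fields and static fields concrete, here is a minimal sketch that registers a standard dataclass as a pytree by hand, using only jax.tree_util.register_pytree_node. It is not the jax_dataclasses implementation; the class Scaled and the helpers _flatten, _unflatten, and apply are hypothetical names chosen for illustration. The array field is returned as a child, while the integer field is returned as auxiliary data and so becomes part of the treedef.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Scaled:
    values: jax.Array  # child: traced by JAX transformations
    power: int  # static: stored in the treedef, never traced


def _flatten(obj):
    # Children are the array-valued fields; aux data holds the static fields.
    return (obj.values,), (obj.power,)


def _unflatten(aux, children):
    return Scaled(values=children[0], power=aux[0])


# Manual registration: a hypothetical stand-in for what a decorator could automate.
jax.tree_util.register_pytree_node(Scaled, _flatten, _unflatten)


@jax.jit
def apply(obj: Scaled) -> jax.Array:
    # `obj.power` is a plain Python int inside the jitted function.
    return obj.values ** obj.power


obj = Scaled(values=jnp.arange(3.0), power=2)
leaves = jax.tree_util.tree_leaves(obj)  # only `values` appears as a leaf
result = apply(obj)
```

Because the static field lives in the treedef, calling the jitted function with a different value for it should trigger a retrace rather than reuse the compiled function.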

Mutations

All dataclasses are automatically marked as frozen and thus immutable (even when no frozen= parameter is passed in). To make changes to nested structures easier, we provide an interface that will (a) make a copy of a pytree and (b) return a context in which any of that copy's contained dataclasses are temporarily mutable:

from jax import numpy as jnp
import jax_dataclasses

@jax_dataclasses.pytree_dataclass
class Node:
  child: jnp.ndarray

obj = Node(child=jnp.zeros(3))

with jax_dataclasses.copy_and_mutate(obj) as obj_updated:
  # Make mutations to the dataclass. This is primarily useful for nested
  # dataclasses.
  #
  # Also does input validation: if the treedef, leaf shapes, or dtypes of `obj`
  # and `obj_updated` don't match, an AssertionError will be raised.
  # This can be disabled with a `validate=False` argument.
  obj_updated.child = jnp.ones(3)

print(obj)
print(obj_updated)

Alternatives

A few other solutions exist for automatically integrating dataclass-style objects into pytree structures. Great ones include: chex.dataclass, flax.struct, and tjax.dataclass. These all influenced this library.

The main differentiators of jax_dataclasses are:

  • Static analysis support. Libraries like dataclasses and attrs rely on tooling-specific custom plugins for static analysis, which don't exist for chex or flax. tjax has a custom mypy plugin to enable type checking, but isn't supported by other tools. Because @jax_dataclasses.pytree_dataclass has the same API as @dataclasses.dataclass, it can include pytree registration behavior at runtime while being treated as the standard decorator during static analysis. This means that all static checkers, language servers, and autocomplete engines that support the standard dataclasses library should work out of the box with jax_dataclasses.

  • Nested dataclasses. Making replacements/modifications in deeply nested dataclasses is generally very frustrating. The three alternatives all introduce a .replace(self, ...) method to dataclasses that's a bit more convenient than the traditional dataclasses.replace(obj, ...) API for shallow changes, but still becomes really cumbersome to use when dataclasses are nested. jax_dataclasses.copy_and_mutate() is introduced to address this.

  • Static field support. Parameters that should not be traced in JAX should be marked as static. This is supported in flax, tjax, and jax_dataclasses, but not chex.

  • Serialization. When working with flax, being able to serialize dataclasses is really handy. This is supported in flax.struct (naturally) and jax_dataclasses, but not chex or tjax.
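
To illustrate the nested dataclasses point above, the following sketch uses only the standard dataclasses library to change a single value three levels deep. The classes Inner, Middle, and Outer are hypothetical and exist only for this example.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Inner:
    value: float


@dataclasses.dataclass(frozen=True)
class Middle:
    inner: Inner


@dataclasses.dataclass(frozen=True)
class Outer:
    middle: Middle


outer = Outer(middle=Middle(inner=Inner(value=0.0)))

# Changing one leaf requires rebuilding every enclosing level by hand.
updated = dataclasses.replace(
    outer,
    middle=dataclasses.replace(
        outer.middle,
        inner=dataclasses.replace(outer.middle.inner, value=1.0),
    ),
)
```

With jax_dataclasses.copy_and_mutate(), as shown in the Mutations section, the same kind of change is written as an attribute assignment on the copy inside the with block.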

Misc

This code was originally written for and factored out of jaxfg, where Nick Heppert provided valuable feedback!

Comments
  • Fix infinite loop for cycles in pytrees

    I have a rather big dataclass to describe a robot model, that includes a graph of links and a list of joints. Each node of the graph references the parent link and all the child links. Each joint object references its parent and child links.

    When I try to copy_and_mutate any of these objects, an infinite loop occurs, possibly because of all this nesting. I suspect that the existing logic tries to unfreeze all the leaves of the pytree, but the high interconnection and the properties of mutable Python types lead to a never-ending unfreezing process.

    This PR addresses this edge case by storing the IDs of objects that have already been unfrozen. It solves my problem, and it should not add any noticeable performance degradation.
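
    The general technique of tracking already-visited object IDs can be sketched as follows. This is not the code from the PR; the Link class and the visit_once function are hypothetical, and the traversal only collects names rather than unfreezing anything.

```python
import dataclasses
from typing import Any, List, Optional, Set


@dataclasses.dataclass
class Link:
    name: str
    parent: Optional["Link"] = None
    children: List["Link"] = dataclasses.field(default_factory=list)


def visit_once(obj: Any, seen: Optional[Set[int]] = None) -> List[str]:
    """Collect the names of all reachable Link objects, visiting each one once."""
    if seen is None:
        seen = set()
    if id(obj) in seen:
        # Already handled: stop here instead of following the cycle again.
        return []
    seen.add(id(obj))

    names: List[str] = []
    if isinstance(obj, Link):
        names.append(obj.name)
        for field in dataclasses.fields(obj):
            names.extend(visit_once(getattr(obj, field.name), seen))
    elif isinstance(obj, (list, tuple)):
        for item in obj:
            names.extend(visit_once(item, seen))
    return names


root = Link("base")
arm = Link("arm", parent=root)
root.children.append(arm)  # root -> arm -> root forms a reference cycle

names = visit_once(root)
```

    Without the seen set, the recursion would follow root, arm, and root again indefinitely.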

    cc @brentyi

    opened by diegoferigo 10
  • Delayed initialisation of static fields

    First of all, thank you for the amazing library! I have recently discovered jax_dataclasses and I have decided to port my messy JAX functional code to a more organised object-oriented code based on jax_dataclasses.

    In my application, some quantities derived from the attributes of the dataclass are static values used to determine the shape of tensors during JIT compilation. I would like to include them as attributes of the dataclass, but I'm getting an error and I would like to know if there is a workaround.

    Here is a simple example, where the attribute _sum is a derived static field that depends on the constant value of the array a.

    import jax
    import jax.numpy as jnp
    import jax_dataclasses as jdc
    
    @jdc.pytree_dataclass()
    class PyTreeDataclass:
        a: jnp.ndarray
        _sum: int = jdc.static_field(init=False, repr=False)
    
        def __post_init__(self):
            object.__setattr__(self, "_sum", self.a.sum().item())
    
    def print_pytree(obj):
        print(obj._sum)
    
    obj = PyTreeDataclass(jnp.arange(4))
    print_pytree(obj)
    jax.jit(print_pytree)(obj)
    

    The non-jitted version works, but when print_pytree is jitted I get the following error.

    File "jax_dataclasses_issue.py", line 14, in __post_init__
        object.__setattr__(self, "_sum", self.a.sum().item())
    AttributeError: 'bool' object has no attribute 'sum'
    

    Is there a way to compute, in __post_init__, the value of static fields that are not initialized in __init__ and that depend on jnp.ndarray attributes of the dataclass?

    opened by lucagrementieri 4
  • `jax.tree_leaves` is deprecated

    The file jax_dataclasses/_copy_and_mutate.py raises many warnings about a deprecated function.

    FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
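
    For reference, the replacement spelling named in the warning can be used as in this small sketch; the example tree is made up for illustration.

```python
import jax
import jax.numpy as jnp

# A small example pytree: a dict containing arrays, a tuple, and a scalar.
tree = {"a": jnp.zeros(2), "b": (jnp.ones(3), 1.0)}

# `jax.tree_util.tree_leaves` is the spelling the warning recommends;
# the top-level `jax.tree_leaves` alias is the deprecated one.
leaves = jax.tree_util.tree_leaves(tree)
```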
    
    opened by lucagrementieri 1
  • Use jaxtyping to enrich type annotations

    I just discovered the jaxtyping library and I think it could be an interesting alternative to the current typing system proposed by jax_dataclasses.

    jaxtyping supports variable-size axes and symbolic expressions in terms of other variable-size axes, see https://github.com/google/jaxtyping/blob/main/API.md and it has very few requirements.

    Do you think that it could be added to jax_dataclasses?

    opened by lucagrementieri 4
  • Serialization of static fields?

    Thanks for the handy library!

    I have a pytree_dataclass that contains a few static_fields that I would like to have serialized by the facilities in flax.serialization. I noticed that jax_dataclasses.asdict handles these, but that flax.serialization.to_state_dict and flax.serialization.to_bytes both ignore them. What is the correct way (if any) to have these fields included in flax's serialization? Should I be using another technique?

    import jax_dataclasses as jdc
    from jax import numpy as jnp
    import flax.serialization as fs
    
    
    @jdc.pytree_dataclass
    class Demo:
        a: jnp.ndarray = jnp.ones(3)
        b: bool = jdc.static_field(default=False)
    
    
    demo = Demo()
    print(f'{jdc.asdict(demo) = }')
    print(f'{fs.to_state_dict(demo) = }')
    print(f'{fs.from_bytes(Demo, fs.to_bytes(demo)) = }')
    
    # jdc.asdict(demo) = {'a': array([1., 1., 1.]), 'b': False}
    # fs.to_state_dict(demo) = {'a': DeviceArray([1., 1., 1.], dtype=float64)}
    # fs.from_bytes(Demo, fs.to_bytes(demo)) = {'a': array([1., 1., 1.])}
    

    Thanks in advance!

    opened by erdmann 3
Releases(v1.5.1)
Owner
Brent Yi