JAX + dataclasses

Overview

jax_dataclasses

jax_dataclasses provides a wrapper around dataclasses.dataclass for use in JAX, which enables automatic support for:

  • Pytree registration. This allows dataclasses to be used at API boundaries in JAX, which is necessary for function transformations, JIT compilation, etc.
  • Serialization via flax.serialization.

Notably, jax_dataclasses is designed to work seamlessly with static analysis, including tools like mypy and jedi.

jax_dataclasses is heavily influenced by some great existing work; see Alternatives for comparisons.

Installation

pip install jax_dataclasses

Core interface

jax_dataclasses is meant to provide a drop-in replacement for dataclasses.dataclass:

  • jax_dataclasses.pytree_dataclass has the same interface as dataclasses.dataclass, but also registers the target class as a pytree container.
  • jax_dataclasses.static_field has the same interface as dataclasses.field, but will also mark the field as static. In a pytree node, static fields will be treated as part of the treedef instead of as a child of the node; all fields that are not explicitly marked static should contain arrays or child nodes.
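
As a rough illustration of what it means for a static field to be "part of the treedef", the sketch below registers a plain dataclass by hand with jax.tree_util.register_pytree_node. It does not use jax_dataclasses and is not its implementation; the class Params, its fields, and the helpers _flatten and _unflatten are hypothetical names chosen for the example.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Params:  # hypothetical example class
    weights: jnp.ndarray  # array field: becomes a pytree child
    activation: str       # static field: stored in the treedef


def _flatten(p):
    # First tuple: children (traced by JAX). Second tuple: auxiliary data,
    # which must be hashable and is kept in the treedef.
    return (p.weights,), (p.activation,)


def _unflatten(aux, children):
    (weights,) = children
    (activation,) = aux
    return Params(weights=weights, activation=activation)


jax.tree_util.register_pytree_node(Params, _flatten, _unflatten)


@jax.jit
def apply(p: Params, x: jnp.ndarray) -> jnp.ndarray:
    y = p.weights * x
    # Ordinary Python branching works here because the field is static.
    return jnp.tanh(y) if p.activation == "tanh" else y


params = Params(weights=jnp.ones(3), activation="tanh")
leaves = jax.tree_util.tree_leaves(params)  # only the array field is a leaf
out = apply(params, jnp.zeros(3))
```

Because activation lives in the treedef, calling apply with a different activation value causes jax.jit to retrace, while passing different weights of the same shape and dtype does not.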

We also provide several aliases: jax_dataclasses.[field, asdict, astuple, is_dataclass, replace] are all identical to their counterparts in the standard dataclasses library.

Mutations

All dataclasses are automatically marked as frozen and thus immutable (even when no frozen= parameter is passed in). To make changes to nested structures easier, we provide an interface that will (a) make a copy of a pytree and (b) return a context in which any of that copy's contained dataclasses are temporarily mutable:

from jax import numpy as jnp
import jax_dataclasses

@jax_dataclasses.pytree_dataclass
class Node:
  child: jnp.ndarray

obj = Node(child=jnp.zeros(3))

with jax_dataclasses.copy_and_mutate(obj) as obj_updated:
  # Make mutations to the dataclass. This is primarily useful for nested
  # dataclasses.
  #
  # Also does input validation: if the treedef, leaf shapes, or dtypes of `obj`
  # and `obj_updated` don't match, an AssertionError will be raised.
  # This can be disabled with a `validate=False` argument.
  obj_updated.child = jnp.ones(3)

print(obj)
print(obj_updated)

Alternatives

A few other solutions exist for automatically integrating dataclass-style objects into pytree structures. Great ones include: chex.dataclass, flax.struct, and tjax.dataclass. These all influenced this library.

The main differentiators of jax_dataclasses are:

  • Static analysis support. Libraries like dataclasses and attrs rely on tooling-specific custom plugins for static analysis, and no such plugins exist for chex or flax. tjax has a custom mypy plugin to enable type checking, but it isn't supported by other tools. Because @jax_dataclasses.pytree_dataclass has the same API as @dataclasses.dataclass, it can include pytree registration behavior at runtime while being treated as the standard decorator during static analysis. This means that all static checkers, language servers, and autocomplete engines that support the standard dataclasses library should work out of the box with jax_dataclasses.

  • Nested dataclasses. Making replacements/modifications in deeply nested dataclasses is generally very frustrating. The three alternatives all introduce a .replace(self, ...) method to dataclasses that's a bit more convenient than the traditional dataclasses.replace(obj, ...) API for shallow changes, but still becomes really cumbersome to use when dataclasses are nested. jax_dataclasses.copy_and_mutate() is introduced to address this.

  • Static field support. Parameters that should not be traced in JAX should be marked as static. This is supported in flax, tjax, and jax_dataclasses, but not chex.

  • Serialization. When working with flax, being able to serialize dataclasses is really handy. This is supported in flax.struct (naturally) and jax_dataclasses, but not chex or tjax.
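
To make the "Nested dataclasses" point concrete, here is a small sketch that uses only the standard dataclasses library. The classes Inner, Middle, and Outer are hypothetical; the example only shows how the replace-style API grows with nesting depth.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Inner:  # hypothetical example classes
    value: float


@dataclasses.dataclass(frozen=True)
class Middle:
    inner: Inner


@dataclasses.dataclass(frozen=True)
class Outer:
    middle: Middle


obj = Outer(middle=Middle(inner=Inner(value=0.0)))

# Changing one deeply nested field needs one replace() call per level.
updated = dataclasses.replace(
    obj,
    middle=dataclasses.replace(
        obj.middle,
        inner=dataclasses.replace(obj.middle.inner, value=1.0),
    ),
)
```

With jax_dataclasses.copy_and_mutate(), the same kind of change is written as a direct attribute assignment on the copy inside the context.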

Misc

This code was originally written for and factored out of jaxfg, where Nick Heppert provided valuable feedback!

Comments
  • Fix infinite loop for cycles in pytrees

    I have a rather big dataclass to describe a robot model, that includes a graph of links and a list of joints. Each node of the graph references the parent link and all the child links. Each joint object references its parent and child links.

    When I try to copy_and_mutate any of these objects, an infinite loop occurs, possibly because of all this nesting. I suspect that the existing logic tries to unfreeze all the leaves of the pytree, but the high interconnection and the properties of mutable Python types lead to a never-ending unfreezing process.

    This PR addresses this edge case by storing the IDs of objects that have already been unfrozen. It solves my problem, and it should not add any noticeable performance degradation.
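
The idea of tracking already-visited object IDs can be sketched roughly as follows. This is an illustrative traversal in plain Python, not the code from the PR; the Link class and the unfreeze function are hypothetical stand-ins.

```python
from typing import Any, List, Optional, Set


class Link:
    """Hypothetical graph node that references its parent and children."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.parent: Optional["Link"] = None
        self.children: List["Link"] = []
        self.unfrozen = False


def unfreeze(obj: Any, visited: Optional[Set[int]] = None) -> int:
    """Mark every reachable Link as unfrozen; return how many were visited."""
    if visited is None:
        visited = set()
    if not isinstance(obj, Link) or id(obj) in visited:
        return 0  # not a node, or already handled: stop instead of looping
    visited.add(id(obj))
    obj.unfrozen = True
    count = 1
    if obj.parent is not None:
        count += unfreeze(obj.parent, visited)
    for child in obj.children:
        count += unfreeze(child, visited)
    return count


root, leaf = Link("root"), Link("leaf")
root.children.append(leaf)
leaf.parent = root  # cycle: root -> leaf -> root

visited_count = unfreeze(root)
```

Without the visited set, the recursion would bounce between root and leaf until Python raised a RecursionError.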

    cc @brentyi

    opened by diegoferigo 10
  • Delayed initialisation of static fields

    First of all, thank you for the amazing library! I have recently discovered jax_dataclasses and I have decided to port my messy JAX functional code to a more organised object-oriented code based on jax_dataclasses.

    In my application, some quantities derived from the attributes of the dataclass are static values used to determine the shape of tensors during JIT compilation. I would like to include them as attributes of the dataclass, but I'm getting an error and would like to know if there is a workaround.

    Here is a simple example, where the attribute _sum is a derived static field that depends on the constant value of the array a.

    import jax
    import jax.numpy as jnp
    import jax_dataclasses as jdc
    
    @jdc.pytree_dataclass()
    class PyTreeDataclass:
        a: jnp.ndarray
        _sum: int = jdc.static_field(init=False, repr=False)
    
        def __post_init__(self):
            object.__setattr__(self, "_sum", self.a.sum().item())
    
    def print_pytree(obj):
        print(obj._sum)
    
    obj = PyTreeDataclass(jnp.arange(4))
    print_pytree(obj)
    jax.jit(print_pytree)(obj)
    

    The non-jitted version works, but when print_pytree is jitted I get the following error.

    File "jax_dataclasses_issue.py", line 14, in __post_init__
        object.__setattr__(self, "_sum", self.a.sum().item())
    AttributeError: 'bool' object has no attribute 'sum'
    

    Is there a way to compute, in __post_init__, the value of static fields that are not initialized in __init__ and that depend on jnp.ndarray attributes of the dataclass?

    opened by lucagrementieri 4
  • `jax.tree_leaves` is deprecated

    The file jax_dataclasses/_copy_and_mutate.py raises many warnings about a deprecated function:

    FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
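
For reference, the replacement named in the warning can be called as in the minimal sketch below; the example tree is made up for illustration.

```python
import jax

# Hypothetical example pytree: a dict containing an int and a list of ints.
tree = {"a": 1, "b": [2, 3]}

# Namespaced replacement for the deprecated top-level jax.tree_leaves.
leaves = jax.tree_util.tree_leaves(tree)
```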
    
    opened by lucagrementieri 1
  • Use jaxtyping to enrich type annotations

    I just discovered the jaxtyping library and I think it could be an interesting alternative to the current typing system proposed by jax_dataclasses.

    jaxtyping supports variable-size axes and symbolic expressions in terms of other variable-size axes (see https://github.com/google/jaxtyping/blob/main/API.md), and it has very few requirements.

    Do you think that it could be added to jax_dataclasses?

    opened by lucagrementieri 4
  • Serialization of static fields?

    Thanks for the handy library!

    I have a pytree_dataclass that contains a few static_fields that I would like to have serialized by the facilities in flax.serialization. I noticed that jax_dataclasses.asdict handles these, but that flax.serialization.to_state_dict and flax.serialization.to_bytes both ignore them. What is the correct way (if any) to have these fields included in flax's serialization? Should I be using another technique?

    import jax_dataclasses as jdc
    from jax import numpy as jnp
    import flax.serialization as fs
    
    
    @jdc.pytree_dataclass
    class Demo:
        a: jnp.ndarray = jnp.ones(3)
        b: bool = jdc.static_field(default=False)
    
    
    demo = Demo()
    print(f'{jdc.asdict(demo) = }')
    print(f'{fs.to_state_dict(demo) = }')
    print(f'{fs.from_bytes(Demo, fs.to_bytes(demo)) = }')
    
    # jdc.asdict(demo) = {'a': array([1., 1., 1.]), 'b': False}
    # fs.to_state_dict(demo) = {'a': DeviceArray([1., 1., 1.], dtype=float64)}
    # fs.from_bytes(Demo, fs.to_bytes(demo)) = {'a': array([1., 1., 1.])}
    

    Thanks in advance!

    opened by erdmann 3
Releases: v1.5.1
Owner: Brent Yi