Overview

Objax

Tutorials | Install | Documentation | Philosophy

This is not an officially supported Google product.

Objax is an open source machine learning framework that accelerates research and learning thanks to a minimalist object-oriented design and a readable code base. Its name comes from the contraction of Object and JAX -- a popular high-performance framework. Objax is designed by researchers for researchers with a focus on simplicity and understandability. Its users should be able to easily read, understand, extend, and modify it to fit their needs.

This is the developer repository of Objax; there is very little user documentation here. For the full documentation, go to objax.readthedocs.io.

You can find READMEs in the subdirectories of this project, for example:

User installation guide

You can install Objax using pip as follows:

pip install --upgrade objax

Objax supports GPUs but assumes that you already have some version of CUDA installed. Here are the extra steps:

# Update according to your installed CUDA version
CUDA_VERSION=11.0
pip install -f https://storage.googleapis.com/jax-releases/jax_releases.html jaxlib==`python3 -c 'import jaxlib; print(jaxlib.__version__)'`+cuda`echo $CUDA_VERSION | sed s:\\\.::g`

Useful environment configurations

Here are a few useful options:

# Prevent JAX from taking the whole GPU memory
# (useful if you want to run several programs on a single GPU)
export XLA_PYTHON_CLIENT_PREALLOCATE=false

Testing your installation

You can test your installation by running the code below:

import jax
import objax

print(f'Number of GPUs {jax.device_count()}')

x = objax.random.normal(shape=(100, 4))
m = objax.nn.Linear(nin=4, nout=5)
print('Matrix product shape', m(x).shape)  # (100, 5)

x = objax.random.normal(shape=(100, 3, 32, 32))
m = objax.nn.Conv2D(nin=3, nout=4, k=3)
print('Conv2D return shape', m(x).shape)  # (100, 4, 32, 32)

If you get errors running this with CUDA, it usually means your installation of CUDA or cuDNN has issues.

Running code examples

Clone the code repository:

git clone https://github.com/google/objax.git
cd objax/examples
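
Then run any of the example scripts. For instance, to run the CIFAR-10 example referenced later on this page (assuming its dependencies, such as TensorFlow for data loading, are installed):

python3 classify/img/cifar10_simple.py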

Citing Objax

To cite this repository:

@software{objax2020github,
  author = {{Objax Developers}},
  title = {{Objax}},
  url = {https://github.com/google/objax},
  version = {1.2.0},
  year = {2020},
}

Developer documentation

Here is information about development setup and a guide on adding new code.

Comments
  • More control over var/module namespace.

    I got my first 'hello world' model experiment working w/ Objax. I adapted my PyTorch EfficientNet impl. Overall pretty smooth, currently wrapping Conv2d so I can get the padding I want.

    One thing that stuck out after inspecting the model: the var namespace is a mess. An aspect of modelling that I value highly is the ability to have sensible checkpoint/var maps to work with. I often end up dealing with conversions between frameworks and exports for mobile or embedded targets, so having your vars (parameters) sensibly named, and often being able to control those names in the originating framework, is important.

    Any thoughts on improving this? The current name/scoping mechanism forces the inclusion of the Module class names, is that necessary? Shouldn't attr names through the tree be enough for uniqueness?

    Also, there is no ability to specify names for modules in sequential containers. I use this quite often for frameworks that have it. Sometimes I don't care much (long list of block repeats, 0..n is fine), but for finer grained blocks I like to know what conv is what by looking at the var names. '0.b, o.w' etc isn't very useful.

    I'll post an example of the var keys below, and comparison point for pytorch.

    feature request 
    opened by rwightman 29
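
    As an illustrative sketch only (not the poster's example, which is not reproduced here), the keys of a VarCollection can be inspected like this for a small model:

    import objax

    # Build a small model and list its variable names.
    model = objax.nn.Sequential([objax.nn.Conv2D(3, 16, k=3),
                                 objax.nn.BatchNorm2D(16),
                                 objax.functional.relu])
    for name in model.vars().keys():
        # Keys embed the module class names, e.g. '(Sequential)[0](Conv2D).w'.
        print(name)
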
  • upsample2d function rough draft

    Hi team, I am pretty new to contributing to open-source projects. Please review the upsample2d function and let me know if anything is required or should be changed. The function is added to the objax.function.ops module.

    opened by naruto-raj 22
  • Add mean squared logarithmic loss function

    1. Added mean squared logarithmic loss function
    2. In the CONTRIBUTIONS.md file, there is no mention of code style, so I am using 4 spaces.
    3. I haven't formatted the code using black, as there is no mention of any formatter either.

    I will add the tests once the above points are clear

    opened by AakashKumarNain 16
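
    For reference, mean squared logarithmic error is usually computed as the mean of (log1p(y_true) - log1p(y_pred))^2; a minimal sketch (independent of this PR's actual code):

    import jax.numpy as jnp

    def mean_squared_log_error(y_true, y_pred):
        # Mean of (log(1 + y_true) - log(1 + y_pred))^2 over all elements.
        return jnp.mean(jnp.square(jnp.log1p(y_true) - jnp.log1p(y_pred)))
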
  • Initial dot product attention

    Adds attention, per #61. First, I'm really sorry about taking so long, but college got complicated in the pandemic and I wasted a lot of time getting organized. Also, attention is a quite general concept, and even implementations of the same type of attention differ significantly (Haiku, Flax). So @david-berthelot and @aterzis-google, I would like to ask a few questions just to make sure my implementation is going in the right direction:

    1. I think I will implement a dot product attention, a multi-head attention and a masked attention, is that ok?
    2. What do you think of the dot product attention implementation? What do you think I need to change? Thanks for the patience and opportunity.
    opened by joaogui1 12
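
    For context, (scaled) dot-product attention computes softmax(Q·K^T / sqrt(d))·V; a minimal JAX sketch (not this PR's implementation):

    import jax.numpy as jnp
    from jax.nn import softmax

    def dot_product_attention(q, k, v):
        # q, k, v: (..., seq_len, d) arrays; scores are scaled by sqrt(d).
        d = q.shape[-1]
        scores = jnp.einsum('...qd,...kd->...qk', q, k) / jnp.sqrt(d)
        return softmax(scores, axis=-1) @ v
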
  • "objax.variable.VarCollection is not a valid JAX type" when creating a custom optimizer

    Hi, I wish to create a custom optimizer to replace the opt(lr=lr, grads=g) line in the example https://github.com/google/objax/blob/master/examples/classify/img/cifar10_simple.py

    Instead, I replaced it with

    for grad, p in zip(g, model_vars):
        p.value -= lr * grad
    

    and then supplied model.vars() as an argument to train_op. However, I received an error: objax.variable.VarCollection is not a valid JAX type. Can someone help me with this issue? Here is a minimal working example which reproduces the error.

    import random
    import numpy as np
    import tensorflow as tf
    import objax
    from objax.zoo.wide_resnet import WideResNet
    
    # Data
    (X_train, Y_train), (X_test, Y_test) = tf.keras.datasets.cifar10.load_data()
    X_train = X_train.transpose(0, 3, 1, 2) / 255.0
    X_test = X_test.transpose(0, 3, 1, 2) / 255.0
    
    # Model
    model = WideResNet(nin=3, nclass=10, depth=28, width=2)
    #opt = objax.optimizer.Adam(model.vars())
    predict = objax.Jit(lambda x: objax.functional.softmax(model(x, training=False)),
                        model.vars())
    # Losses
    def loss(x, label):
        logit = model(x, training=True)
        return objax.functional.loss.cross_entropy_logits_sparse(logit, label).mean()
    
    gv = objax.GradValues(loss, model.vars())
    
    def train_op(x, y, model_vars, lr):
        g, v = gv(x, y)
        for grad, p in zip(g, model_vars):
            p.value -= lr * grad
        return v
    
    
    # gv.vars() contains the model variables.
    train_op = objax.Jit(train_op, gv.vars()) #I deleted opt.vars()
    
    for epoch in range(30):
        # Train
        loss = []
        sel = np.arange(len(X_train))
        np.random.shuffle(sel)
        for it in range(0, X_train.shape[0], 64):
            loss.append(train_op(X_train[sel[it:it + 64]], Y_train[sel[it:it + 64]].flatten(), model.vars(), 4e-3 if epoch < 20 else 4e-4)) #I added model.vars() 
    
    opened by RXZ2020 11
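
    For context (a sketch, not an official answer): jitted functions trace their arguments as JAX arrays, so a VarCollection cannot be passed as a regular argument. The usual pattern, as in the cifar10_simple example, is to capture the variables from the enclosing scope and declare them in the collection given to objax.Jit; a hand-written update can follow the same shape:

    # Reusing `model` and `gv` from the example above.
    opt = objax.optimizer.SGD(model.vars())

    def train_op(x, y, lr):
        g, v = gv(x, y)   # gradients and loss value
        opt(lr, g)        # updates the model variables in place
        return v

    # Variables are declared to Jit instead of being passed as call arguments.
    train_op = objax.Jit(train_op, gv.vars() + opt.vars())
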
  • Enforcing positivity (or other transformations) of TrainVars

    Hi,

    Is it possible to declare constraints on trainable variables, e.g. forcing them to be positive via an exponential or softplus transformation?

    In an ideal world, we would be able to write something like: self.variance = objax.TrainVar(np.array(1.0), transform=positive)

    Thanks,

    Will

    p.s. thanks for the great work on objax so far, it's a pleasure to use.

    opened by wil-j-wil 10
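
    One possible pattern while such an API does not exist (an illustrative sketch, not an existing Objax feature): store an unconstrained TrainVar and expose the constrained value through a transformation such as softplus:

    import jax
    import jax.numpy as jnp
    import objax

    class PositiveVariance(objax.Module):
        """Stores an unconstrained TrainVar; exposes a positive value via softplus."""
        def __init__(self, init=1.0):
            # Store the softplus pre-image so the initial constrained value equals `init`.
            self.raw = objax.TrainVar(jnp.log(jnp.expm1(jnp.asarray(init))))

        @property
        def value(self):
            return jax.nn.softplus(self.raw.value)

    A module field such as self.variance = PositiveVariance(1.0) would then read its constrained value as self.variance.value while training the unconstrained raw parameter.
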
  • Training state as a Module attribute

    As mentioned in a Twitter thread, I am curious about the decision to propagate training state through the call() chain. From my perspective this approach adds more boilerplate code and more chances of making a mistake (not propagating the state to a few instances of a module with a BN or dropout layer, etc.). If the state changed on every call like the input data, it would make more sense to pass it with every forward, but I can't think of cases where that is common. For small models it doesn't make much difference, but as they grow with more depth and breadth of submodules, the extra args are more noticeable.

    I feel one of the major benefits of an OO abstraction for NN is being able to push some attributes like this into the class structure vs forcing it to be forwarded through every call in a functional manner. I sit in the middle (pragmatic) ground of OO vs functional. Hidden state can be problematic, but it's worth it if it keeps interfaces clean.

    Besides TF/Keras, most DL libs manage training state as a module attribute or some sort of context:

    • PyTorch - nn.Module has a self.training attribute, recursively set on train()/eval() calls on the model/modules - https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.eval
    • MXNet Gluon - a context manager sets the scope: with autograd.train_mode() / with autograd.predict_mode() - https://gluon.mxnet.io/chapter03_deep-neural-networks/mlp-dropout-gluon.html
    • Swift for TF - a thread-local context holds learningPhase - https://www.tensorflow.org/swift/api_docs/Structs/Context

    It should be noted that Swift for TF started out Keras- and Objax-like, with the training state passed through call().

    Disclaimer: I like PyTorch; I do quite a bit of work with that framework. It's not perfect, but I feel they really did a good job in terms of interface, usability, and evolution of the API. I've read some other comments here and acknowledge the 'we don't want to be like framework/lib X or Y just because; if you disagree go fork yourself' stance. Understood: any suggestions I make are not just to be like X, but to bring elements of X that work really well to improve this library.

    I currently maintain some PyTorch model collections, https://github.com/rwightman/pytorch-image-models and https://github.com/rwightman/efficientdet-pytorch as examples. I'm running into a cost ($$) wall with the GPU experiments supporting my OS work. TPU costing is starting to look far more attractive. PyTorch XLA is not proving to be a great option, but JAX with a productive interface looks like it could be a winning solution with even more flexibility.

    I'm willing to contribute code for changes like this, but at this point it's a matter of design philosophy :)

    opened by rwightman 9
  • Implementing 2 phases DP-SGD

    This PR implements a two-phase algorithm for per-sample gradient clipping with the goal of improving memory efficiency for the training of private deep models. The two steps are: (1) accumulate the norms of the gradient per sample and (2) use those norm values to perform a weighted backward pass that is equivalent to per-sample clipping. The user can choose whether to use this new algorithm or the currently implemented one through a boolean argument.

    The unit-tests have been adapted to check results for both algorithms.

    Let me know if this fits well!

    opened by lberrada 7
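
    For context, the per-sample weight that makes a weighted backward pass equivalent to clipping is the standard DP-SGD factor min(1, C / ||g_i||); a minimal sketch of that step (not this PR's code):

    import jax.numpy as jnp

    def clipping_weights(per_sample_norms, clip_norm):
        # w_i = min(1, clip_norm / ||g_i||), so each weighted per-sample gradient
        # has norm at most clip_norm.
        return jnp.minimum(1.0, clip_norm / jnp.maximum(per_sample_norms, 1e-12))
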
  • Give better error message when calling Parallel() without replicate()

    Currently if you forget to call replicate() on a Parallel module, it dies somewhere in JAX land, in between the 5th and 6th circles of hell. This error message makes it possible to understand what's going on and find your way back.

    opened by carlini 7
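
    For reference, the intended usage (sketched here from the pmean example further down this page; `model` and the input batch `x` are assumed to exist) replicates the variables before calling the Parallel module:

    import objax

    model_vars = model.vars()
    predict = objax.Parallel(lambda v: objax.functional.softmax(model(v, training=False)),
                             vc=model_vars)

    with model_vars.replicate():   # forgetting this call is what triggers the confusing error
        y = predict(x)             # the batch is split across devices and the results gathered
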
  • Naming of the `GradValues` function

    If I understand right, GradValues essentially does two things: computing gradients and computing the model's final values.

    So why not split it into two functions? Or if we keep the current form, could we name it GradAndValuesFn? I'm just thinking this is a prominent function and want to keep it as easy as possible for people beginning to use the framework. Easy names like fit() and predict() made scikit-learn.

    opened by jli05 6
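
    For readers new to the API, a small usage sketch (with placeholder `loss`, `model`, `x`, `y`) showing the two things it returns in one call:

    import objax

    gv = objax.GradValues(loss, model.vars())
    g, v = gv(x, y)   # g: gradients w.r.t. the trainable vars, v: the value(s) returned by loss
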
  • Explicit padding mode

    It looks like objax currently limits padding to one of VALID or SAME. This prevents the use of explicit padding and would hinder compatibility with models from PyTorch or Gluon, which only support explicit (symmetric) padding, unless extra Pad layers are added to the model.

    It'd be nice to at minimum add the ability to support TF-style explicit padding (specify both sides of every dim); the underlying jax conv impl is able to receive a [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]] spec like other low-level TF convs.

    Even nicer would be simplified, per-spatial-dim symmetric values like PyTorch/Gluon's [pad_h, pad_w] or just pad. My default for most 2D convnets in PyTorch is to use pad = ((stride - 1) + dilation * (kernel_size - 1)) // 2, which results in a 'same-ish' padding value. This can always be done on top of the full low/high padding sequence above.

    Some TF models explicitly work around the limitations of SAME padding. By limitations, I mean the fact that you end up with input-dependent padding that can be asymmetric and shift your feature maps relative to each other in a manner that varies as you change your input size. https://github.com/tensorflow/models/blob/146a37c6663e4a249e02d3dff0087b576e3dc3a1/research/deeplab/core/xception.py#L81-L201

    Possible interfaces:

    • padding : Union[ConvPadding, Sequence[Tuple[int, int]]] (like conv_general_dilated but with the enum for valid/same)

    • Add more modes to the enum and associated values for those that need them via a dataclass:

    class PaddingType(enum.Enum):
      """An Enum holding the possible padding values for convolution modules."""
        SAME = 'SAME'
        VALID = 'VALID'
        RAW = 'RAW'  # specify padding as seq of high/low tuples
        SYM = 'SYM'  # specify symmetric padding for spatial dim as tuple for H, W or single int
    
    @dataclass
    class Padding:
        type: PaddingType = PaddingType.SAME
        value: Union[Sequence[Tuple[int, int]], Tuple[int, int], int] = None
    
        @classmethod
        def same(cls):
            return Padding(PaddingType.SAME)
    
        @classmethod
        def valid(cls):
            return Padding(PaddingType.VALID)
    
        @classmethod
        def raw(cls, value: Sequence[Tuple[int, int]]):
            return Padding(PaddingType.RAW, value=value)
    
        @classmethod
        def sym(cls, value: Union[Tuple[int, int], int]):
            return Padding(PaddingType.SYM, value=value)
    
    feature request 
    opened by rwightman 6
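
    For reference, the underlying JAX convolution already accepts an explicit (low, high) padding pair per spatial dimension, as mentioned above; a minimal sketch:

    import jax
    import jax.numpy as jnp

    x = jnp.zeros((1, 3, 32, 32))   # NCHW input
    w = jnp.zeros((8, 3, 3, 3))     # OIHW kernel
    # Explicit (low, high) padding for each spatial dimension.
    y = jax.lax.conv_general_dilated(x, w, window_strides=(1, 1),
                                     padding=[(1, 1), (1, 1)])
    print(y.shape)   # (1, 8, 32, 32)
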
  • `objax.variable.VarCollection.update` not compliant with key-value assignment

    Hi everyone! Thanks for the awesome work with objax and the JAX environment, and happy holidays!

    I'm trying to load some VarCollection and/or Dict[str, jnp.DeviceArray] params into model.vars(), which is a VarCollection, and I can do so by:

    for key, value in new_params.items():
        model.vars()[key].assign(value)
    

    But I'd expect objax.variable.VarCollection.update to work the same way e.g.

    model.vars().update(new_params)
    

    But the latter doesn't work while the former does; I'm not sure if that's because it's not the intended behavior for VarCollection.update or if I'm doing something wrong. The first approach works, which for the moment is fine for what I need, but I wanted to mention this in case something isn't working as expected.

    opened by alvarobartt 1
  • `objax.variable.VarCollection.update` fails when passing `Dict[str, Any]`

    Hi everyone! Thanks for the awesome work with objax and the JAX environment, and happy holidays!

    I was playing around with objax for a bit and realized that if you try to update model.vars() (which is a VarCollection) using the VarCollection.update method, which overrides the default dict.update, it fails when what you pass is a Python dictionary rather than a VarCollection: the argument is cast into a Python list, and the code then loops over the items of that list as if it were a Python dictionary, so it throws ValueError: too many values to unpack (expected 2).

    https://github.com/google/objax/blob/53b391bfa72dc59009c855d01b625049a35f5f1b/objax/variable.py#L311-L318

    Is this intended? Shouldn't VarCollection.update just loop over classes that allow .items()?

    opened by alvarobartt 0
  • Update nn.rst

    The channel number for 'in' is currently set as c, which is incorrect because c refers to the output channel number. Instead it needs to be set as t (the variable that iterates over the input channel numbers): in[n,c,i+h,j+w] should be changed to in[n,t,i+h,j+w].

    opened by divyas248 1
  • pmean inside objax.parallel causes multithreading deadlock for more than 2 gpus

    Hi, I've noticed a problem where I'd like to ask for your expertise. I'm not entirely sure whether it is an objax problem or rather a JAX problem under the hood, but as it is triggered by objax commands I'll post it here.

    Description

    In particular, when combining objax.Parallel and objax.functional.pmean (as done in this tutorial) I encounter problems with more than 2 GPUs (with 2 GPUs it works fine). It results in a deadlock situation, where nothing happens anymore. If I understand the tutorial correctly, the pmean is necessary to average the gradients of all cards.

    Minimal reproducible example

    import objax
    import numpy as np
    from objax.zoo.resnet_v2 import ResNet18
    from jax import numpy as jnp, device_count
    from tqdm import tqdm
    
    
    if __name__ == "__main__":
        print(f"Num devices: {device_count()}")
        model = ResNet18(3, 1)
        opt = objax.optimizer.SGD(model.vars())
    
        @objax.Function.with_vars(model.vars())
        def loss(x, label):
            return objax.functional.loss.mean_squared_error(
                model(x, training=True), label
            ).mean()
    
        gv = objax.GradValues(loss, model.vars())
    
        train_vars = model.vars() + gv.vars() + opt.vars()
    
        @objax.Function.with_vars(train_vars)
        def train_op(
            image_batch,
            label_batch,
        ):
    
            grads, loss = gv(image_batch, label_batch)
            # grads = objax.functional.parallel.pmean(grads) # this line
            # loss = objax.functional.parallel.pmean(loss) # and this line
            loss = loss[0]
            opt(1e-3, grads)
            return loss, grads
    
        train_op = objax.Parallel(train_op, reduce=jnp.mean, vc=train_vars)
    
        with (train_vars).replicate():
            for _ in tqdm(range(10), total=10):
                data = jnp.array(np.random.randn(512, 3, 224, 224))
                label = jnp.zeros((512, 1))
                loss, grads = train_op(data, label)
    
    

    Whenever you uncomment the two lines with pmean, the program gets stuck. However, if I understood it correctly, this is necessary to get the average of the gradients over all cards.

    Error traces

    As with most deadlock bugs, you don't get an error stack trace. However, I have found two clues so far. One is that if this is uncommented, the following appears:

    2022-08-22 14:55:46.462557: E external/org_tensorflow/tensorflow/compiler/xla/service/rendezvous.cc:31] This thread has been waiting for 10 seconds and may be stuck:
    2022-08-22 14:55:48.543291: E external/org_tensorflow/tensorflow/compiler/xla/service/rendezvous.cc:36] Thread is unstuck! Warning above was a false-positive. Perhaps the timeout is too short.
    

    The other is that if I manually interrupt it with Ctrl+C, I get this lengthy stack trace.

    Setup

    We use 4 NVIDIA A40 GPUs with CUDA version 11.7 (driver version 515.65.01), cuDNN 8.2.1.32, jax version 0.3.15, and objax version 1.6.0.

    opened by a1302z 3
Releases(v1.6.0)
  • v1.6.0(Feb 1, 2022)

  • v1.4.0(Apr 1, 2021)

    • Added prototype of ducktyping of Objax variables as JAX arrays
    • Added prototype of automatic variable tracing
    • Added learning rate scheduler
    • Various bugfixes
    Source code(tar.gz)
    Source code(zip)
  • v1.3.1(Feb 3, 2021)

  • v1.3.0(Jan 29, 2021)

    • Feature: Improved error messages overall
    • Feature: Improved BatchNorm numerical stability
    • Feature: Objax2Tf for serving objax using TensorFlow
    • Feature: New API objax.optimizer.ExponentialMovingAverageModule for easy moving average of a model
    • Feature: Automatic broadcasting of scalars for objax.Parallel
    • Feature: New optimizer: LARS
    • Feature: New API added to functional (lax.scan)
    • Feature: Modules can be printed to nicely readable text now (repr)
    • Feature: New interpolate API (for images)
    • Bugfix: make objax.Sequential work with latest JAX
    Source code(tar.gz)
    Source code(zip)
  • v1.2.0(Nov 2, 2020)

    • Feature: Improved error messages.

    • Feature: Extended syntax: allow assigning TrainVar without TrainRef for direct experimentation.

    • Feature: Extended padding options for pad and convolution.

    • Feature: Modified ResNet_V2 to be Keras compatible.

    • Feature: Defaults can be overridden in call for Adam, Momentum.

    • BugFix: Layer norm initialization in GPT-2.

    Source code(tar.gz)
    Source code(zip)