Jaxtorch (a jax nn library)

Related tags

Deep Learningjaxtorch
Overview

Jaxtorch (a jax nn library)

This is my jax based nn library. I created this because I was annoyed by the complexity and 'magic'-ness of the popular jax frameworks (flax, haiku).

The objective is to enable pytorch-like model definition and training with a minimum of magic. Simple example:

import jax
import jax.numpy as jnp
import jaxlib
import jaxtorch

# Modules are just classes that inherit jaxtorch.Module
class Linear(jaxtorch.Module):
    # They can accept any constructor parameters
    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        # Parameters are represented by a Param type, which identifies
        # them, and specifies how to initialize them.
        self.weight = jaxtorch.init.glorot_normal(out_features, in_features)
        assert type(self.weight) is jaxtorch.Param
        if bias:
            self.bias = jaxtorch.init.zeros(out_features)
        else:
            self.bias = None

    # The forward function accepts cx, a Context object as the first argument
    # always. This provides random number generation as well as the parameters.
    def forward(self, cx: jaxtorch.Context, x):
        # Parameters are looked up in the context using the stored identifier.
        y = x @ jnp.transpose(cx[self.weight])
        if self.bias:
            y = y + cx[self.bias]
        return y

model = Linear(3, 3)

# You initialize the weights by passing a RNG key.
# Calling init_weights also names all the parameters in the Module tree.
params = model.init_weights(jax.random.PRNGKey(0))

# Parameters are stored in a dictionary by name.
assert type(params) is dict
assert type(params[model.weight.name]) is jaxlib.xla_extension.DeviceArray
assert model.weight.name == 'weight'

def loss(params, key):
    cx = jaxtorch.Context(params, key)
    x = jnp.array([1.0,2.0,3.0])
    y = jnp.array([4.0,5.0,6.0])
    return jnp.mean((model(cx, x) - y)**2)
f_grad = jax.value_and_grad(loss)

for _ in range(100):
    (loss, grad) = f_grad(params, jax.random.PRNGKey(0))
    params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grad)
print(loss)
# 4.7440533e-08
Owner
nshepperd
nshepperd
Pixel-wise segmentation on VOC2012 dataset using pytorch.

PiWiSe Pixel-wise segmentation on the VOC2012 dataset using pytorch. FCN SegNet PSPNet UNet RefineNet For a more complete implementation of segmentati

Bodo Kaiser 378 Dec 30, 2022
This code is for eCaReNet: explainable Cancer Relapse Prediction Network.

eCaReNet This code is for eCaReNet: explainable Cancer Relapse Prediction Network. (Towards Explainable End-to-End Prostate Cancer Relapse Prediction

Institute of Medical Systems Biology 2 Jul 28, 2022
Any-to-any voice conversion using synthetic specific-speaker speeches as intermedium features

MediumVC MediumVC is an utterance-level method towards any-to-any VC. Before that, we propose SingleVC to perform A2O tasks(Xi → Ŷi) , Xi means utter

谷下雨 47 Dec 25, 2022
A visualization tool to show a TensorFlow's graph like TensorBoard

tfgraphviz tfgraphviz is a module to visualize a TensorFlow's data flow graph like TensorBoard using Graphviz. tfgraphviz enables to provide a visuali

44 Nov 09, 2022
Training DALL-E with volunteers from all over the Internet using hivemind and dalle-pytorch (NeurIPS 2021 demo)

Training DALL-E with volunteers from all over the Internet This repository is a part of the NeurIPS 2021 demonstration "Training Transformers Together

<a href=[email protected]"> 19 Dec 13, 2022
git《Joint Entity and Relation Extraction with Set Prediction Networks》(2020) GitHub:

Joint Entity and Relation Extraction with Set Prediction Networks Source code for Joint Entity and Relation Extraction with Set Prediction Networks. W

130 Dec 13, 2022
PyTorch Implementation of PortaSpeech: Portable and High-Quality Generative Text-to-Speech

PortaSpeech - PyTorch Implementation PyTorch Implementation of PortaSpeech: Portable and High-Quality Generative Text-to-Speech. Model Size Module Nor

Keon Lee 279 Jan 04, 2023
CodeContests is a competitive programming dataset for machine-learning

CodeContests CodeContests is a competitive programming dataset for machine-learning. This dataset was used when training AlphaCode. It consists of pro

DeepMind 1.6k Jan 08, 2023
Easy Parallel Library (EPL) is a general and efficient deep learning framework for distributed model training.

English | 简体中文 Easy Parallel Library Overview Easy Parallel Library (EPL) is a general and efficient library for distributed model training. Usability

Alibaba 185 Dec 21, 2022
Streamlit app demonstrating an image browser for the Udacity self-driving-car dataset with realtime object detection using YOLO.

Streamlit Demo: The Udacity Self-driving Car Image Browser This project demonstrates the Udacity self-driving-car dataset and YOLO object detection in

Streamlit 992 Jan 04, 2023
Implementation of CVAE. Trained CVAE on faces from UTKFace Dataset to produce synthetic faces with a given degree of happiness/smileyness.

Conditional Smiles! (SmileCVAE) About Implementation of AE, VAE and CVAE. Trained CVAE on faces from UTKFace Dataset. Using an encoding of the Smile-s

Raúl Ortega 3 Jan 09, 2022
Definition of a business problem according to Wilson Lower Bound Score and Time Based Average Rating

Wilson Lower Bound Score, Time Based Rating Average In this study I tried to calculate the product rating and sorting reviews more accurately. I have

3 Sep 30, 2021
Reinforcement Learning Theory Book (rus)

Reinforcement Learning Theory Book (rus)

qbrick 206 Nov 27, 2022
curl-impersonate: A special compilation of curl that makes it impersonate Chrome & Firefox

curl-impersonate A special compilation of curl that makes it impersonate real browsers. It can impersonate the four major browsers: Chrome, Edge, Safa

lwthiker 1.9k Jan 03, 2023
General Multi-label Image Classification with Transformers

General Multi-label Image Classification with Transformers Jack Lanchantin, Tianlu Wang, Vicente Ordóñez Román, Yanjun Qi Conference on Computer Visio

QData 154 Dec 21, 2022
Reviving Iterative Training with Mask Guidance for Interactive Segmentation

This repository provides the source code for training and testing state-of-the-art click-based interactive segmentation models with the official PyTorch implementation

Visual Understanding Lab @ Samsung AI Center Moscow 406 Jan 01, 2023
Re-implement CycleGAN in Tensorlayer

CycleGAN_Tensorlayer Re-implement CycleGAN in TensorLayer Original CycleGAN Improved CycleGAN with resize-convolution Prerequisites: TensorLayer Tenso

89 Aug 15, 2022
(CVPR 2022) Pytorch implementation of "Self-supervised transformers for unsupervised object discovery using normalized cut"

(CVPR 2022) TokenCut Pytorch implementation of Tokencut: Self-supervised Transformers for Unsupervised Object Discovery using Normalized Cut Yangtao W

YANGTAO WANG 200 Jan 02, 2023
A scikit-learn-compatible module for estimating prediction intervals.

MAPIE - Model Agnostic Prediction Interval Estimator MAPIE allows you to easily estimate prediction intervals (or prediction sets) using your favourit

588 Jan 04, 2023
P-Tuning v2: Prompt Tuning Can Be Comparable to Finetuning Universally Across Scales and Tasks

P-tuning v2 P-Tuning v2: Prompt Tuning Can Be Comparable to Finetuning Universally Across Scales and Tasks An optimized prompt tuning strategy for sma

THUDM 540 Dec 30, 2022