JMP is a Mixed Precision library for JAX.


Mixed precision training in JAX


Installation | Examples | Policies | Loss scaling | Citing JMP | References

Mixed precision training [0] is a technique that mixes the use of full and half precision floating point numbers during training to reduce the memory bandwidth requirements and improve the computational efficiency of a given model.
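Here is what the memory saving looks like in practice. A half precision value occupies 2 bytes and a float32 value occupies 4, so the same tensor needs half the memory and half the bandwidth to move. The sketch below uses only `jax.numpy`, and the 1024 x 1024 shape is an arbitrary illustration:

```python
import jax.numpy as jnp

# An arbitrary 1024 x 1024 tensor stored in three floating point formats.
shape = (1024, 1024)
f32 = jnp.zeros(shape, dtype=jnp.float32)
f16 = jnp.zeros(shape, dtype=jnp.float16)
bf16 = jnp.zeros(shape, dtype=jnp.bfloat16)

# Bytes needed to store (and therefore to move) each tensor.
print(f32.nbytes)   # 4194304 (4 MiB)
print(f16.nbytes)   # 2097152 (2 MiB)
print(bf16.nbytes)  # 2097152 (2 MiB)
```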

This library implements support for mixed precision training in JAX by providing two key abstractions: mixed precision "policies" and loss scaling. Neural network libraries (such as Haiku) can integrate with JMP to provide "Automatic Mixed Precision (AMP)" support, automating or simplifying the application of policies to modules.

All code examples below assume the following:

import jax
import jax.numpy as jnp
import jmp

half = jnp.float16  # On TPU this should be jnp.bfloat16.
full = jnp.float32
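The right value for `half` depends on the accelerator. The minimal sketch below picks it automatically. It assumes that `jax.default_backend()` reporting `"tpu"` is a good enough signal, and `pick_half_dtype` is a hypothetical helper name. JMP also provides `jmp.half_dtype()`, used later in this README, which appears to serve the same purpose.

```python
import jax
import jax.numpy as jnp

def pick_half_dtype():
  """Hypothetical helper: bfloat16 on TPU, float16 everywhere else."""
  if jax.default_backend() == "tpu":
    return jnp.bfloat16
  return jnp.float16

half = pick_half_dtype()
full = jnp.float32
print(jax.default_backend(), jnp.dtype(half).name)
```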

Installation

JMP is written in pure Python, but depends on C++ code via JAX and NumPy.

Because JAX installation is different depending on your CUDA version, JMP does not list JAX as a dependency in requirements.txt.

First, follow the JAX installation instructions to install JAX with the relevant accelerator support.

Then, install JMP using pip:

$ pip install git+https://github.com/deepmind/jmp

Examples

You can find a fully worked JMP example in Haiku which shows how to use mixed f32/f16 precision to halve training time on GPU and mixed f32/bf16 to reduce training time on TPU by a third.

Policies

A mixed precision policy encapsulates the configuration in a mixed precision experiment.

# Our policy specifies that we will store parameters in full precision but will
# compute and return output in half precision.
my_policy = jmp.Policy(compute_dtype=half,
                       param_dtype=full,
                       output_dtype=half)

The policy object can be used to cast pytrees:

def layer(params, x):
  params, x = my_policy.cast_to_compute((params, x))
  w, b = params
  y = x @ w + b
  return my_policy.cast_to_output(y)

x = jnp.ones([4, 3], dtype=full)
params = (jnp.ones([3, 2], dtype=my_policy.param_dtype),
          jnp.zeros([2], dtype=my_policy.param_dtype))
y = layer(params, x)
assert y.dtype == half

You can replace the output type of a given policy:

my_policy = my_policy.with_output_dtype(full)

You can also define a policy via a string, which may be useful for specifying a policy as a command-line argument or as a hyperparameter to your experiment:

my_policy = jmp.get_policy("params=float32,compute=float16,output=float32")
float16_policy = jmp.get_policy("float16")  # Everything in f16.
half_policy = jmp.get_policy("half")        # Everything in half (f16 or bf16).

Loss scaling

When training with reduced precision, consider whether gradients will need to be shifted into the representable range of the format that you are using. This is particularly important when training with float16 and less important for bfloat16. See the NVIDIA mixed precision user guide [1] for more details.
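The difference between the two formats shows up in their representable ranges, which `jnp.finfo` reports directly. float16 has a maximum finite value of 65,504 and a smallest normal value of about 6.1e-5. bfloat16 shares float32's exponent range and gives up mantissa bits instead, which is why it usually needs far less help with gradient range.

```python
import jax.numpy as jnp

for dtype in (jnp.float16, jnp.bfloat16, jnp.float32):
  info = jnp.finfo(dtype)
  # max: largest finite value; tiny: smallest positive normal value;
  # nmant: number of explicitly stored mantissa bits.
  print(f"{jnp.dtype(dtype).name:>8}: max={float(info.max):.4g} "
        f"tiny={float(info.tiny):.4g} mantissa_bits={info.nmant}")
```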

The easiest way to shift gradients is with loss scaling, which scales your loss and gradients by S and 1/S respectively.
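A tiny worked example shows why scaling helps. A gradient entry of 1e-8 is below float16's smallest subnormal value (about 6e-8), so it rounds to zero. Multiplied by S = 2**15, it lands inside float16's range and can be divided back out in float32. The value 1e-8 is chosen only for illustration.

```python
import jax.numpy as jnp

tiny_grad = 1e-8   # illustrative gradient value
S = 2.0 ** 15      # loss scale

# Without scaling, storing the value in float16 flushes it to zero.
unscaled = jnp.asarray(tiny_grad, dtype=jnp.float16)

# With scaling, the float16 value is about 3.3e-4, which float16 represents.
scaled = jnp.asarray(tiny_grad * S, dtype=jnp.float16)

# Unscale in float32 to recover (approximately) the original gradient.
recovered = scaled.astype(jnp.float32) / S

print(float(unscaled))   # 0.0
print(float(recovered))  # close to 1e-08
```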

def my_loss_fn(params, loss_scale: jmp.LossScale, ...):
  loss = ...
  # You should apply regularization etc before scaling.
  loss = loss_scale.scale(loss)
  return loss

def train_step(params, loss_scale: jmp.LossScale, ...):
  grads = jax.grad(my_loss_fn)(...)
  grads = loss_scale.unscale(grads)
  # You should put gradient clipping etc after unscaling.
  params = apply_optimizer(params, grads)
  return params

loss_scale = jmp.StaticLossScale(2 ** 15)
for _ in range(num_steps):
  params = train_step(params, loss_scale, ...)

The appropriate value for S depends on your model, loss, batch size and potentially other factors, and you can determine it by trial and error. As a rule of thumb, you want the largest value of S that does not introduce overflow during backprop. NVIDIA [1] recommends computing statistics about the gradients of your model (in full precision) and picking S such that its product with the maximum norm of your gradients is below 65,504, the largest finite float16 value.
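The sketch below is one way to turn that rule of thumb into code. The helper is hypothetical and not part of JMP. It reads the largest absolute entry across a pytree of full precision gradients, which is how "maximum norm" is interpreted here. It then returns the largest power of two S whose product with that entry stays below float16's maximum of 65,504. Restricting S to powers of two keeps scaling and unscaling exact, barring overflow or underflow.

```python
import math

import jax
import jax.numpy as jnp

F16_MAX = float(jnp.finfo(jnp.float16).max)  # 65504.0

def suggest_static_loss_scale(grads) -> float:
  """Largest power of two S with S * max|g| < F16_MAX (hypothetical helper)."""
  max_abs = max(float(jnp.max(jnp.abs(g)))
                for g in jax.tree_util.tree_leaves(grads))
  if max_abs == 0.0:
    raise ValueError("All gradients are zero; cannot infer a loss scale.")
  exponent = math.floor(math.log2(F16_MAX / max_abs))
  # Step down once more if the product would land exactly on F16_MAX.
  if 2.0 ** exponent * max_abs >= F16_MAX:
    exponent -= 1
  return 2.0 ** exponent

# Example gradient statistics gathered in float32 (values are made up).
grads = {"w": jnp.array([0.5, -3.0]), "b": jnp.array(0.01)}
print(suggest_static_loss_scale(grads))  # 16384.0
```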

We provide a dynamic loss scale, which adjusts the loss scale periodically during training to find the largest value for S that produces finite gradients. This is more convenient and robust compared with picking a static loss scale, but has a small performance impact (between 1 and 5%).

def my_loss_fn(params, loss_scale: jmp.LossScale, ...):
  loss = ...
  # You should apply regularization etc before scaling.
  loss = loss_scale.scale(loss)
  return loss

def train_step(params, loss_scale: jmp.LossScale, ...):
  grads = jax.grad(my_loss_fn)(...)
  grads = loss_scale.unscale(grads)
  # You should put gradient clipping etc after unscaling.

  # You definitely want to skip non-finite updates with the dynamic loss scale,
  # but you might also want to consider skipping them when using a static loss
  # scale if you experience NaN's when training.
  skip_nonfinite_updates = isinstance(loss_scale, jmp.DynamicLossScale)

  if skip_nonfinite_updates:
    grads_finite = jmp.all_finite(grads)
    # Adjust our loss scale depending on whether gradients were finite. The
    # loss scale will be periodically increased if gradients remain finite and
    # will be decreased if not.
    loss_scale = loss_scale.adjust(grads_finite)
    # Only apply our optimizer if grads are finite, if any element of any
    # gradient is non-finite the whole update is discarded.
    params = jmp.select_tree(grads_finite, apply_optimizer(params, grads), params)
  else:
    # With static or no loss scaling just apply our optimizer.
    params = apply_optimizer(params, grads)

  # Since our loss scale is dynamic we need to return the new value from
  # each step. All loss scales are `PyTree`s.
  return params, loss_scale

loss_scale = jmp.DynamicLossScale(jmp.half_dtype()(2 ** 15))
for _ in range(num_steps):
  params, loss_scale = train_step(params, loss_scale, ...)

In general, using a static loss scale should offer the best speed, but we have optimized dynamic loss scaling to make it competitive. We recommend that you start with dynamic loss scaling and move to static loss scaling if performance is an issue.

Finally, we offer a no-op loss scale which you can use as a drop-in replacement. It does nothing apart from implementing the jmp.LossScale API:

loss_scale = jmp.NoOpLossScale()
assert loss is loss_scale.scale(loss)
assert grads is loss_scale.unscale(grads)
assert loss_scale is loss_scale.adjust(grads_finite)
assert loss_scale.loss_scale == 1

Citing JMP

This repository is part of the DeepMind JAX Ecosystem; to cite JMP, please use the DeepMind JAX Ecosystem citation.

References

[0] Paulius Micikevicius, Sharan Narang, Jonah Alben, Gregory Diamos, Erich Elsen, David Garcia, Boris Ginsburg, Michael Houston, Oleksii Kuchaiev, Ganesh Venkatesh, Hao Wu: "Mixed Precision Training", 2017; arXiv:1710.03740 https://arxiv.org/abs/1710.03740.

[1] "Training With Mixed Precision", NVIDIA Deep Learning Performance Documentation, 2020; https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/.

Comments
  • Questions around speedup

    Hi,

    Thanks for creating this amazing library!

    So, from what I understand, this is the minimum needed to benefit from the JMP speedup:

    policy = jmp.Policy(param_dtype=jnp.float32, compute_dtype=jnp.float16, output_dtype=jnp.float16)
    ...
    # creating a network, and creating a loss_fn using the network
    data = policy.cast_to_compute(data)
    params = policy.cast_to_compute(params)
    
    grads = jax.grad(loss_fn)(data, params)
    ...
    grads = policy.cast_to_param(grads)
    

    The loss scale is only needed if we experience NaN or inf values. And thus we should see a difference in the time needed for a jax.grad operation when the data and params are in float16.

    Please correct me if I'm wrong or if I miss anything.

    Considering that, I have some questions:

    • Can we expect to see a 2x or any speedup if we time the jax.grad operation with (params, data) in float16 and in float32 or is the speedup only achieved at scale ?
    • Can we expect to see a 2x or any speedup with small networks and small experiments ? (say 2-layer nets and experiments that only need 2-3 minutes on a single gpu)
    • Can we expect to see a 2x or any speedup with all types of networks or only some with specific architectures ? (say a 50 layer MLP, Transformers, etc)
    • Is it necessary to apply the mixed precision policy to the network and thus to use Haiku for hk.mixed_precision.set_policy or is having the parameters and data in float16 sufficient to have a speedup even with a network created by us ?

    Thank you and have a nice day !

    opened by 1m1ne 3
  • Basic question about the use of my_policy

    Hi,

    Thanks for creating this amazing functionality!

    I have a basic question about the use of policy functions to set certain precision levels. As stated in the example of your README.md:

    def layer(params, x):
      params, x = my_policy.cast_to_compute((params, x))
      w, b = params
      y = x @ w + b
      return my_policy.cast_to_output(y)
    

    I am new to JAX, but the first thing I learned is that JAX likes pure functions. Does the use of my_policy violate the pure functions paradigm?

    Should it become:

    def layer(params, x, my_policy):
      params, x = my_policy.cast_to_compute((params, x))
      w, b = params
      y = x @ w + b
      return my_policy.cast_to_output(y)
    

    The function can be jitted by using partial.

    from functools import partial
    layer_compiled = jit(partial(layer, my_policy=my_policy))
    

    Fundamental question

    A more fundamental question (which I may need to ask in the JAX repository) is: how should functions passed as inputs to other functions be handled in JAX?

    def func_a(x, w, func):
      ...
    
    func_a_compiled = jit(partial(func_a, func=func_b))
    

    In order to jit this, I came up with the partial solution above. I assume your use of my_policy is valid, as you probably have more experience with JAX. But that creates some magic, which is undesirable for my use case. Is the jit(partial()) solution valid or is there a better way to handle functions as input to functions?

    Have a nice day! J

    opened by JSchuurmans 2
  • jmp-0.0.2.tar.gz on PyPI doesn't contain requirements.txt

    When trying to build my own wheel from the tar-ball of jmp-0.0.2 that has been uploaded to PyPI I get the following error:

    Collecting jmp==0.0.2
      Downloading jmp-0.0.2.tar.gz (13 kB)
        Running command python setup.py egg_info
        Traceback (most recent call last):
          File "<string>", line 1, in <module>
          File "/tmp/pip-download-zowwz99o/jmp_55ae4a2c4cdb44aa99c65fbf7a5ee9bb/setup.py", line 55, in <module>
            install_requires=_parse_requirements('requirements.txt'),
          File "/tmp/pip-download-zowwz99o/jmp_55ae4a2c4cdb44aa99c65fbf7a5ee9bb/setup.py", line 32, in _parse_requirements
            with open(requirements_txt_path) as fp:
        FileNotFoundError: [Errno 2] No such file or directory: 'requirements.txt'
    WARNING: Discarding https://files.pythonhosted.org/packages/7c/ba/a6bfcaeedca8551e2fb4054d1fd061a0dd97d26dd44002b3e92d13b51877/jmp-0.0.2.tar.gz#sha256=fdb5cec0d10aab4116c2770f24b2adf4f503fcfbb96ce8ef583e1879bdbf1b9b (from https://pypi.org/simple/jmp/). Command errored out with exit status 1: python setup.py egg_info Check the logs for full command output.
    ERROR: Could not find a version that satisfies the requirement jmp==0.0.2 (from versions: 0.0.1, 0.0.2)
    ERROR: No matching distribution found for jmp==0.0.2
    

    After downloading and extracting jmp-0.0.2.tar.gz from PyPI and looking at setup.py, I see that it contains lines that reference requirements.txt:

    install_requires=_parse_requirements('requirements.txt')
    

    however that file is not part of the source distribution and my build fails.

    Oliver

    P.S. Yes I know that there is a wheel for jmp-0.0.2 on PyPI, which I ended up using, but I'm using our Wheels_builder script that will recursively build wheels for dependencies for our systems. My point is that the source distribution of jmp 0.0.2 is incomplete as it lacks the information that is contained in the requirements file.

    opened by ostueker 1
  • Casting numpy array

    I'll start by saying that I've only just started with JAX, so I might be doing something wrong.

    When I run the following code:

    precision = jmp.get_policy('params=float32,compute=float16,output=float16')
    some_input = np.arange(15).reshape((5, 3))
    @jax.jit
    def some_function(some_input):
        some_input = precision.cast_to_compute(some_input)
        print(some_input.dtype)
        # prints float32
        return a_float16_compute_model(some_input)  # fails
    

    It seems that at least on the first (tracing) run, the numpy array doesn't get cast to float16, maybe because it is being treated as a tree, here

    https://github.com/deepmind/jmp/blob/4b94370b8de29b79d6f840b09d1990b91c1afddd/jmp/_src/policy.py#L26

    Surprisingly, if I run

    some_input = np.arange(15).reshape((5, 3)).astype(precision.compute_dtype)
    print(some_input.dtype)
    
    #prints float16
    

    the cast succeeds. I think that the expected behavior is that precision.cast_to_compute(some_input) should return a numpy array with compute_dtype, but I might be missing something.

    opened by yardenas 1
  • [JAX] Fix test failure due to upcoming change to JAX.

    An upcoming change to JAX adds a @jit decorator around a number of array operators, in this case division (/). A side effect of that change is that large Python integer constants that overflow an int32 or int64 type may produce an error. The workaround is either to explicitly cast the large constants to a specific type (e.g., np.float64), or in this case we can just do the math in question in classic NumPy since it is computing test expectations.

    opened by copybara-service[bot] 0
  • Bump version to `0.0.3.dev`.

    This is so that, if users install from GitHub, we can see in bug reports that they are not using a stable version:

    $ pip install git+https://github.com/deepmind/jmp
    $ python -c 'import jmp ; print(jmp.__version__)'
    0.0.3.dev
    
    opened by copybara-service[bot] 0
  • Cut 0.0.2 release and create GitHub action to publish to PyPi.

    We are not using 0.0.1 because this version was already used by the previous owner of the PyPI package (https://pypi.org/project/jmp/0.0.1/). As such, our releases will start at 0.0.2.

    opened by copybara-service[bot] 0
  • Replace jax.lax.select with jnp.where

    Thanks for the awesome work!

    This PR fixes an issue where jax.lax.select complains about dtypes not being equal when adjusting the DynamicLossScale.

    The exception's stack trace ends with:

    File ".../jmp/_src/loss_scale.py", line 147, in adjust
        loss_scale = jax.lax.select(
    TypeError: lax.select requires arguments to have the same dtypes, got float32, int32. (Tip: jnp.where is a similar function that does automatic type promotion on inputs).
    

    My code looks similar to

    scale = jmp.DynamicLossScale(jnp.asarray(2 ** 15))
    ...
    gradients, scale = gradient_fn(..., scale)
    gradients = scale.unscale(gradients)
    gradients_finite = jmp.all_finite(gradients)
    scale = scale.adjust(gradients_finite)  # This line throws the exception
    ...
    
    opened by nlsfnr 6
Releases (v0.0.2)
  • v0.0.2(Apr 15, 2021)

    Initial release of JMP.

    Changelog:

    • Add jmp.Policy abstraction and jmp.get_policy(..) factory.
    • Add jmp.LossScale and three implementations thereof (noop, static and dynamic).
    • Add various utilities (jmp.all_finite) to support common tasks in mixed precision codebases.
Owner
DeepMind