JMP is a Mixed Precision library for JAX.

Mixed precision training in JAX

Installation | Examples | Policies | Loss scaling | Citing JMP | References

Mixed precision training [0] is a technique that mixes the use of full and half precision floating point numbers during training to reduce the memory bandwidth requirements and improve the computational efficiency of a given model.

This library implements support for mixed precision training in JAX by providing two key abstractions: mixed precision "policies" and loss scaling. Neural network libraries (such as Haiku) can integrate with jmp to provide "Automatic Mixed Precision (AMP)" support, which automates or simplifies applying policies to modules.

All code examples below assume the following:

import jax
import jax.numpy as jnp
import jmp

half = jnp.float16  # On TPU this should be jnp.bfloat16.
full = jnp.float32

Installation

JMP is written in pure Python, but depends on C++ code via JAX and NumPy.

Because JAX installation is different depending on your CUDA version, JMP does not list JAX as a dependency in requirements.txt.

First, follow these instructions to install JAX with the relevant accelerator support.

Then, install JMP using pip:

$ pip install git+https://github.com/deepmind/jmp

Examples

You can find a fully worked JMP example in Haiku which shows how to use mixed f32/f16 precision to halve training time on GPU and mixed f32/bf16 to reduce training time on TPU by a third.

Policies

A mixed precision policy encapsulates the configuration in a mixed precision experiment.

# Our policy specifies that we will store parameters in full precision but will
# compute and return output in half precision.
my_policy = jmp.Policy(compute_dtype=half,
                       param_dtype=full,
                       output_dtype=half)

The policy object can be used to cast pytrees:

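def layer(params, x):
  # Cast parameters and inputs to the compute dtype before doing any math.
  params, x = my_policy.cast_to_compute((params, x))
  w, b = params["w"], params["b"]
  y = x @ w + b
  return my_policy.cast_to_output(y)

params = {"w": jnp.ones([3, 2], dtype=my_policy.param_dtype),
          "b": jnp.zeros([2], dtype=my_policy.param_dtype)}
x = jnp.ones([4, 3], dtype=full)
y = layer(params, x)
assert y.dtype == half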

You can replace the output type of a given policy:

my_policy = my_policy.with_output_dtype(full)

You can also define a policy via a string, which may be useful for specifying a policy as a command-line argument or as a hyperparameter to your experiment:

my_policy = jmp.get_policy("params=float32,compute=float16,output=float32")
float16_policy = jmp.get_policy("float16")  # Everything in f16.
half_policy = jmp.get_policy("half")        # Everything in half (f16 or bf16).

Loss scaling

When training with reduced precision, consider whether gradients will need to be shifted into the representable range of the format that you are using. This is particularly important when training with float16 and less important for bfloat16. See the NVIDIA mixed precision user guide [1] for more details.
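
To make the range issue concrete, here is a small NumPy sketch (NumPy has no bfloat16, so only float16 is shown). The gradient value 1e-8 and the factor 2 ** 15 are illustrative assumptions, not recommendations.

```python
import numpy as np

f16 = np.finfo(np.float16)
print("largest finite float16:", f16.max)

# A small gradient value that float32 represents without trouble.
tiny_grad = np.float32(1e-8)

# Cast directly to float16: the value is below the representable range
# and flushes to zero.
as_half = np.float16(tiny_grad)

# Multiply by a power of two first, then cast: the value now survives.
factor = np.float32(2 ** 15)
scaled_half = np.float16(tiny_grad * factor)

# Dividing by the same factor in float32 recovers the original magnitude.
recovered = np.float32(scaled_half) / factor

print(as_half, scaled_half, recovered)
```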

The easiest way to shift gradients is with loss scaling, which scales your loss and gradients by S and 1/S respectively.

def my_loss_fn(params, loss_scale: jmp.LossScale, ...):
  loss = ...
  # You should apply regularization etc before scaling.
  loss = loss_scale.scale(loss)
  return loss

def train_step(params, loss_scale: jmp.LossScale, ...):
  grads = jax.grad(my_loss_fn)(...)
  grads = loss_scale.unscale(grads)
  # You should put gradient clipping etc after unscaling.
  params = apply_optimizer(params, grads)
  return params

loss_scale = jmp.StaticLossScale(2 ** 15)
for _ in range(num_steps):
  params = train_step(params, loss_scale, ...)

The appropriate value for S depends on your model, loss, batch size and potentially other factors. You can determine this with trial and error. As a rule of thumb you want the largest value of S that does not introduce overflow during backprop. NVIDIA [1] recommends computing statistics about the gradients of your model (in full precision) and picking S such that its product with the maximum norm of your gradients is below 65,504, the largest finite float16 value.
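
As a sketch of that rule of thumb, the helper below picks the largest power-of-two S whose product with a measured maximum gradient norm stays below 65,504. Restricting S to powers of two is an assumption of this example, and largest_power_of_two_scale is a hypothetical name, not a jmp function.

```python
import math

FLOAT16_MAX = 65504.0  # Largest finite float16 value.

def largest_power_of_two_scale(max_grad_norm: float) -> float:
  """Returns the largest power of two S with S * max_grad_norm < FLOAT16_MAX."""
  if not max_grad_norm > 0:
    raise ValueError("max_grad_norm must be positive")
  scale = 2.0 ** math.floor(math.log2(FLOAT16_MAX / max_grad_norm))
  # Guard against the product landing exactly on the limit.
  while scale * max_grad_norm >= FLOAT16_MAX:
    scale /= 2.0
  return scale

# Suppose the largest gradient norm observed in full precision was 3.7.
print(largest_power_of_two_scale(3.7))
```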

We provide a dynamic loss scale, which adjusts the loss scale periodically during training to find the largest value for S that produces finite gradients. This is more convenient and robust compared with picking a static loss scale, but has a small performance impact (between 1 and 5%).

def my_loss_fn(params, loss_scale: jmp.LossScale, ...):
  loss = ...
  # You should apply regularization etc before scaling.
  loss = loss_scale.scale(loss)
  return loss

def train_step(params, loss_scale: jmp.LossScale, ...):
  grads = jax.grad(my_loss_fn)(...)
  grads = loss_scale.unscale(grads)
  # You should put gradient clipping etc after unscaling.

  # You definitely want to skip non-finite updates with the dynamic loss scale,
  # but you might also want to consider skipping them when using a static loss
  # scale if you experience NaNs when training.
  skip_nonfinite_updates = isinstance(loss_scale, jmp.DynamicLossScale)

  if skip_nonfinite_updates:
    grads_finite = jmp.all_finite(grads)
    # Adjust our loss scale depending on whether gradients were finite. The
    # loss scale will be periodically increased if gradients remain finite and
    # will be decreased if not.
    loss_scale = loss_scale.adjust(grads_finite)
    # Only apply our optimizer if grads are finite; if any element of any
    # gradient is non-finite, the whole update is discarded.
    params = jmp.select_tree(grads_finite, apply_optimizer(params, grads), params)
  else:
    # With static or no loss scaling just apply our optimizer.
    params = apply_optimizer(params, grads)

  # Since our loss scale is dynamic we need to return the new value from
  # each step. All loss scales are `PyTree`s.
  return params, loss_scale

loss_scale = jmp.DynamicLossScale(jmp.half_dtype()(2 ** 15))
for _ in range(num_steps):
  params, loss_scale = train_step(params, loss_scale, ...)
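
The adjustment rule described above (grow the scale after a run of finite steps, shrink it on a non-finite step) can be sketched in plain Python. This is an illustration of the general scheme under assumed defaults; ToyDynamicScale and its field names are hypothetical, and this is not jmp's implementation.

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class ToyDynamicScale:
  scale: float = 2.0 ** 15
  counter: int = 0       # Consecutive finite steps since the last change.
  period: int = 2000     # Finite steps required before growing the scale.
  factor: float = 2.0    # Multiplicative growth and shrink factor.
  min_scale: float = 1.0

  def adjust(self, grads_finite: bool) -> "ToyDynamicScale":
    if not grads_finite:
      # Overflow: shrink the scale and restart the count.
      new_scale = max(self.min_scale, self.scale / self.factor)
      return replace(self, scale=new_scale, counter=0)
    if self.counter + 1 == self.period:
      # A full period of finite steps: try a larger scale.
      return replace(self, scale=self.scale * self.factor, counter=0)
    return replace(self, counter=self.counter + 1)

state = ToyDynamicScale(period=3)
history = []
for finite in [True, True, True, False]:
  state = state.adjust(finite)
  history.append(state.scale)
print(history)
```

Returning a new object from adjust mirrors the functional style of the training loop above, where the updated loss scale is returned from each step.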

In general using a static loss scale should offer the best speed, but we have optimized dynamic loss scaling to make it competitive. We recommend you start with dynamic loss scaling and move to static loss scaling if performance is an issue.

Finally, we offer a no-op loss scale which you can use as a drop-in replacement. It does nothing apart from implementing the jmp.LossScale API:

loss_scale = jmp.NoOpLossScale()
assert loss is loss_scale.scale(loss)
assert grads is loss_scale.unscale(grads)
assert loss_scale is loss_scale.adjust(grads_finite)
assert loss_scale.loss_scale == 1

Citing JMP

This repository is part of the DeepMind JAX Ecosystem. To cite JMP, please use the DeepMind JAX Ecosystem citation.

References

[0] Paulius Micikevicius, Sharan Narang, Jonah Alben, Gregory Diamos, Erich Elsen, David Garcia, Boris Ginsburg, Michael Houston, Oleksii Kuchaiev, Ganesh Venkatesh, Hao Wu: "Mixed Precision Training", 2017; arXiv:1710.03740 https://arxiv.org/abs/1710.03740.

[1] "Training With Mixed Precision :: NVIDIA Deep Learning Performance Documentation". Docs.Nvidia.Com, 2020, https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/.

Comments
  • Questions around speedup

    Hi,

    Thanks for creating this amazing library!

    So from what I understood, this is the minimum needed to benefit from the JMP speedup:

    policy = jmp.Policy(param_dtype=jnp.float32, compute_dtype=jnp.float16, output_dtype=jnp.float16)
    ...
    # creating a network, and creating a loss_fn using the network
    data = policy.cast_to_compute(data)
    params = policy.cast_to_compute(params)
    
    grads = jax.grad(loss_fn)(data, params)
    ...
    grads = policy.cast_to_param(grads)
    

    The loss scale is only needed if we experience NaN or inf values. And thus we should see a difference in the time needed for a jax.grad operation when the data and params are in float16.

    Please correct me if I'm wrong or if I miss anything.

    Considering that, I have some questions:

    • Can we expect to see a 2x or any speedup if we time the jax.grad operation with (params, data) in float16 and in float32, or is the speedup only achieved at scale?
    • Can we expect to see a 2x or any speedup with small networks and small experiments (say 2-layer nets and experiments that only need 2-3 minutes on a single GPU)?
    • Can we expect to see a 2x or any speedup with all types of networks, or only with specific architectures (say a 50-layer MLP, Transformers, etc.)?
    • Is it necessary to apply the mixed precision policy to the network, and thus to use Haiku for hk.mixed_precision.set_policy, or is having the parameters and data in float16 sufficient to get a speedup even with a network we created ourselves?

    Thank you and have a nice day!

    opened by 1m1ne 3
  • Basic question about the use of my_policy

    Hi,

    Thanks for creating this amazing functionality!

    I have a basic question about the use of policy functions to set certain precision levels. As stated in the example of your README.md:

    def layer(params, x):
      params, x = my_policy.cast_to_compute((params, x))
      w, b = params
      y = x @ w + b
      return my_policy.cast_to_output(y)
    

    I am new to JAX, but the first thing I learned is that JAX likes pure functions. Does the use of my_policy violate the pure functions paradigm?

    Should it become:

    def layer(params, x, my_policy):
      params, x = my_policy.cast_to_compute((params, x))
      w, b = params
      y = x @ w + b
      return my_policy.cast_to_output(y)
    

    The function can be jitted by using partial.

    from functools import partial
    from jax import jit
    layer_compiled = jit(partial(layer, my_policy=my_policy))
    

    Fundamental question

    A more fundamental question (which I may need to ask the JAX team) is: how should functions passed as inputs to other functions be handled in JAX?

    def func_a(x, w, func):
      ...
    
    func_a_compiled = jit(partial(func_a, func=func_b))
    

    In order to jit this, I came up with the partial solution above. I assume your use of my_policy is valid, as you probably have more experience with JAX. But that creates some magic, which is undesirable for my use case. Is the jit(partial()) solution valid or is there a better way to handle functions as input to functions?

    Have a nice day! J

    opened by JSchuurmans 2
  • jmp-0.0.2.tar.gz on PyPI doesn't contain requirements.txt

    When trying to build my own wheel from the tar-ball of jmp-0.0.2 that has been uploaded to PyPI I get the following error:

    Collecting jmp==0.0.2
      Downloading jmp-0.0.2.tar.gz (13 kB)
        Running command python setup.py egg_info
        Traceback (most recent call last):
          File "<string>", line 1, in <module>
          File "/tmp/pip-download-zowwz99o/jmp_55ae4a2c4cdb44aa99c65fbf7a5ee9bb/setup.py", line 55, in <module>
            install_requires=_parse_requirements('requirements.txt'),
          File "/tmp/pip-download-zowwz99o/jmp_55ae4a2c4cdb44aa99c65fbf7a5ee9bb/setup.py", line 32, in _parse_requirements
            with open(requirements_txt_path) as fp:
        FileNotFoundError: [Errno 2] No such file or directory: 'requirements.txt'
    WARNING: Discarding https://files.pythonhosted.org/packages/7c/ba/a6bfcaeedca8551e2fb4054d1fd061a0dd97d26dd44002b3e92d13b51877/jmp-0.0.2.tar.gz#sha256=fdb5cec0d10aab4116c2770f24b2adf4f503fcfbb96ce8ef583e1879bdbf1b9b (from https://pypi.org/simple/jmp/). Command errored out with exit status 1: python setup.py egg_info Check the logs for full command output.
    ERROR: Could not find a version that satisfies the requirement jmp==0.0.2 (from versions: 0.0.1, 0.0.2)
    ERROR: No matching distribution found for jmp==0.0.2
    

    After downloading and extracting jmp-0.0.2.tar.gz from PyPI and looking at setup.py, I see that it contains lines that reference requirements.txt:

    install_requires=_parse_requirements('requirements.txt')
    

    however that file is not part of the source distribution and my build fails.

    Oliver

    P.S. Yes I know that there is a wheel for jmp-0.0.2 on PyPI, which I ended up using, but I'm using our Wheels_builder script that will recursively build wheels for dependencies for our systems. My point is that the source distribution of jmp 0.0.2 is incomplete as it lacks the information that is contained in the requirements file.

    opened by ostueker 1
  • Casting numpy array

    I'll start by saying that I've just started with JAX, so I might be doing something wrong.

    When I run the following code:

    precision = jmp.get_policy('params=float32,compute=float16,output=float16')
    some_input = np.arange(15).reshape((5, 3))
    @jax.jit
    def some_function(some_input):
        some_input = precision.cast_to_compute(some_input)
        print(some_input.dtype)
        # prints float32
        return a_float16_compute_model(some_input)  # fails
    

    It seems that at least on the first (tracing) run, the numpy array doesn't get cast to float16, maybe because it is being treated as a tree, here

    https://github.com/deepmind/jmp/blob/4b94370b8de29b79d6f840b09d1990b91c1afddd/jmp/_src/policy.py#L26

    Surprisingly, if I run

    some_input = np.arange(15).reshape((5, 3)).astype(precision.compute_dtype)
    print(some_input.dtype)
    
    #prints float16
    

    the cast succeeds. I think that the expected behavior is that precision.cast_to_compute(some_input) should return a numpy array with compute_dtype, but I might be missing something.

    opened by yardenas 1
  • [JAX] Fix test failure due to upcoming change to JAX.

    An upcoming change to JAX adds a @jit decorator around a number of array operators, in this case division (/). A side effect of that change is that large Python integer constants that overflow an int32 or int64 type may produce an error. The workaround is either to explicitly cast the large constants to a specific type (e.g., np.float64), or in this case we can just do the math in question in classic NumPy since it is computing test expectations.

    cla: yes 
    opened by copybara-service[bot] 0
  • Bump version to `0.0.3.dev`.

    Bump version to 0.0.3.dev.

    This is so if users install from GitHub we can see they are not using a stable version in bug reports:

    $ pip install git+https://github.com/deepmind/jmp
    $ python -c 'import jmp ; print(jmp.__version__)'
    0.0.3.dev
    
    cla: yes 
    opened by copybara-service[bot] 0
  • Cut 0.0.2 release and create GitHub action to publish to PyPI.

    We are not using 0.0.1 because this version was already used by the previous owner of the PyPI package (https://pypi.org/project/jmp/0.0.1/). As such, our releases will start at 0.0.2.

    cla: yes 
    opened by copybara-service[bot] 0
  • Replace jax.lax.select with jnp.where

    Thanks for the awesome work!

    This PR fixes an issue where jax.lax.select complains about dtypes not being equal when adjusting the DynamicLossScale.

    The exception's stack trace ends with:

    File ".../jmp/_src/loss_scale.py", line 147, in adjust
        loss_scale = jax.lax.select(
    TypeError: lax.select requires arguments to have the same dtypes, got float32, int32. (Tip: jnp.where is a similar function that does automatic type promotion on inputs).
    

    My code looks similar to

    scale = jmp.DynamicLossScale(jnp.asarray(2 ** 15))
    ...
    gradients, scale = gradient_fn(..., scale)
    gradients = scale.unscale(gradients)
    gradients_finite = jmp.all_finite(gradients)
    scale = scale.adjust(gradients_finite)  # This line throws the exception
    ...
    
    opened by nlsfnr 6
Releases (v0.0.2)
  • v0.0.2 (Apr 15, 2021)

    Initial release of JMP.

    Changelog:

    • Add jmp.Policy abstraction and jmp.get_policy(..) factory.
    • Add jmp.LossScale and three implementations thereof (noop, static and dynamic).
    • Add various utilities (jmp.all_finite) to support common tasks in mixed precision codebases.
Owner
DeepMind