Overview

Elegy

Elegy is a framework-agnostic Trainer interface for the Jax ecosystem.

Main Features

  • Easy-to-use: Elegy provides a Keras-like high-level API that makes it very easy to do common tasks.
  • Flexible: Elegy provides a functional Pytorch Lightning-like low-level API that provides maximal flexibility when needed.
  • Agnostic: Elegy supports a variety of frameworks including Flax, Haiku, and Optax on the high-level API, and it is 100% framework-agnostic on the low-level API.
  • Compatible: Elegy can consume a wide variety of common data sources including TensorFlow Datasets, Pytorch DataLoaders, Python generators, and Numpy pytrees.
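For example, a plain Python generator can be passed straight to Model.fit. A minimal sketch (model, X_train, and y_train are defined in the Quick Start below):

import numpy as np

def batch_generator(batch_size=64):
    # yield (inputs, labels) tuples forever; fit consumes one batch per step
    while True:
        idx = np.random.randint(0, len(X_train), size=batch_size)
        yield X_train[idx], y_train[idx]

model.fit(batch_generator(), steps_per_epoch=200, epochs=10)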

For more information take a look at the Documentation.

Installation

Install Elegy using pip:

pip install elegy

For Windows users we recommend the Windows Subsystem for Linux 2 (WSL2), since Jax does not yet support Windows natively.

Quick Start: High-level API

Elegy's high-level API provides a very simple interface that you can use by implementing the following steps:

1. Define the architecture inside a Module. We will use Flax Linen for this example:

import flax.linen as nn
import jax

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(300)(x)
        x = jax.nn.relu(x)
        x = nn.Dense(10)(x)
        return x

2. Create a Model from this module and specify additional things like losses, metrics, and optimizers:

import elegy, optax

model = elegy.Model(
    module=MLP(),
    loss=[
        elegy.losses.SparseCategoricalCrossentropy(from_logits=True),
        elegy.regularizers.GlobalL2(l=1e-5),
    ],
    metrics=elegy.metrics.SparseCategoricalAccuracy(),
    optimizer=optax.rmsprop(1e-3),
)
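Optionally, you can inspect the architecture with model.summary, which prints a Keras-style table of layers and parameters. A sketch (the sample batch x_sample is hypothetical):

import numpy as np

x_sample = np.random.uniform(size=(64, 784)).astype(np.float32)  # hypothetical sample batch
model.summary(x_sample)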

3. Train the model using the fit method:

model.fit(
    x=X_train,
    y=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[elegy.callbacks.TensorBoard("summaries")]
)
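After training, the same Keras-like methods cover evaluation and inference. A brief sketch using the X_test and y_test arrays from above:

# compute losses and metrics on the test set
logs = model.evaluate(x=X_test, y=y_test, batch_size=64)

# produce logits for new inputs
y_pred = model.predict(x=X_test, batch_size=64)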

Quick Start: Low-level API

Elegy's low-level API lets you define exactly what goes on during training, testing, and inference. Let's define the test_step to implement a linear classifier in pure JAX:

1. Calculate our loss, logs, and states:

import jax
import jax.numpy as jnp
import numpy as np

class LinearClassifier(elegy.Model):
    # request parameters by name via dependency injection.
    # names: x, y_true, sample_weight, class_weight, states, initializing
    def test_step(
        self,
        x, # inputs
        y_true, # labels
        states: elegy.States, # model state
        initializing: bool, # if True we should initialize our parameters
    ):  
        rng: elegy.RNGSeq = states.rng
        # flatten + scale
        x = jnp.reshape(x, (x.shape[0], -1)) / 255
        # initialize or use existing parameters
        if initializing:
            w = jax.random.uniform(
                rng.next(), shape=[np.prod(x.shape[1:]), 10]
            )
            b = jax.random.uniform(rng.next(), shape=[1])
        else:
            w, b = states.net_params
        # model
        logits = jnp.dot(x, w) + b
        # categorical crossentropy loss
        labels = jax.nn.one_hot(y_true, 10)
        loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)
        # metrics
        logs = dict(
            accuracy=accuracy,
            loss=loss,
        )
        return loss, logs, states.update(net_params=(w, b))

2. Instantiate our LinearClassifier with an optimizer:

model = LinearClassifier(
    optimizer=optax.rmsprop(1e-3),
)

3. Train the model using the fit method:

model.fit(
    x=X_train,
    y=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[elegy.callbacks.TensorBoard("summaries")]
)

Using Jax Frameworks

It is straightforward to integrate other functional JAX libraries with this low-level API:

class LinearClassifier(elegy.Model):
    def test_step(
        self, x, y_true, states: elegy.States, initializing: bool
    ):
        rng: elegy.RNGSeq = states.rng
        x = jnp.reshape(x, (x.shape[0], -1)) / 255
        if initializing:
            logits, variables = self.module.init_with_output(
                {"params": rng.next(), "dropout": rng.next()}, x
            )
        else:
            variables = dict(params=states.net_params, **states.net_states)
            logits, variables = self.module.apply(
                variables, x, rngs={"dropout": rng.next()}, mutable=True
            )
        net_states, net_params = variables.pop("params")
        
        labels = jax.nn.one_hot(y_true, 10)
        loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)

        logs = dict(accuracy=accuracy, loss=loss)
        return loss, logs, states.update(net_params=net_params, net_states=net_states)
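The same pattern carries over to Haiku. A minimal sketch, assuming self.module is a haiku.TransformedWithState created with hk.transform_with_state (an assumption, not part of the example above):

class HaikuLinearClassifier(elegy.Model):
    def test_step(
        self, x, y_true, states: elegy.States, initializing: bool
    ):
        rng: elegy.RNGSeq = states.rng
        x = jnp.reshape(x, (x.shape[0], -1)) / 255
        if initializing:
            # hk.transform_with_state's init returns (params, state)
            net_params, net_states = self.module.init(rng.next(), x)
        else:
            net_params, net_states = states.net_params, states.net_states
        # apply returns (outputs, new_state)
        logits, net_states = self.module.apply(net_params, net_states, rng.next(), x)

        labels = jax.nn.one_hot(y_true, 10)
        loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)

        logs = dict(accuracy=accuracy, loss=loss)
        return loss, logs, states.update(net_params=net_params, net_states=net_states)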

More Info

Examples

To run the examples first install some required packages:

pip install -r examples/requirements.txt

Now run the example:

python examples/flax_mnist_vae.py 

Contributing

Deep Learning is evolving at an incredible pace; there is so much to do and so few hands. If you wish to contribute anything from a loss or metric to a new awesome feature for Elegy, just open an issue or send a PR! For more information check out our Contributing Guide.

About Us

We are some friends passionate about ML.

License

Apache

Citing Elegy

To cite this project:

BibTeX

@software{elegy2020repository,
  author = {PoetsAI},
  title = {Elegy: A framework-agnostic Trainer interface for the Jax ecosystem},
  url = {https://github.com/poets-ai/elegy},
  version = {0.5.0},
  year = {2020},
}

where the current version may be retrieved either from the Release tag or from the file elegy/__init__.py, and the year corresponds to the project's release year.

Comments
  • Framework Agnostic API: Introduces a new low-level API, removes the dependency between Model and Module, adds support for Flax and Haiku, simplifies hooks.

    As noted below, this PR contains the following features:

    • It turns Elegy into a framework-agnostic library by removing the dependencies between elegy.Model and elegy.Module; it proposes the GeneralizedModule API and implements it for Flax, Haiku, Elegy Module types, and regular Python functions.
    • It introduces a new low-level API similar to Pytorch Lightning that lets users manually override the core parts of the training loop when maximal flexibility is required.
    • General changes that enable the framework-agnostic mindset.
    • Many quality of life changes like standardization of hooks, simplification of the Module system, etc.

    Tasks:

    • [x] Create hooks module
    • [x] Refactor Model with low-level API and remove Module dependencies
    • [x] Refactor Module to use hooks
    • [x] Create GeneralizedModule and GeneralizedOptimizer Interfaces
    • [x] Implement GeneralizedModule for flax.linen.Module
    • [x] Implement GeneralizedModule for elegy.Module
    • [x] Implement GeneralizedModule for haiku.Module
    • [x] Implement GeneralizedOptimizer for optax.GradientTransformation
    • [x] Implement GeneralizedOptimizer for elegy.Optimizer
    • [x] Fix Model.summary
    • [x] Fix tests
    • [x] Fix examples
    • [ ] Fix README
    • [ ] Fix guides
    • [ ] Fix docstrings
    opened by cgarciae 27
  • WGAN-GP low-level API example

    A more extensive example using the new low-level API: Wasserstein-GAN with Gradient Penalty (WGAN-GP) trained on the CelebA dataset.

    Some good generated images (samples from epochs 0079, 0084, and 0089; images not shown here).

    Some notes:

    • I first tried to train a DCGAN, which uses binary crossentropy, but I ran into balancing issues: the discriminator quickly becomes too good, so the generator does not learn anything. The same model implemented in PyTorch or TensorFlow works. Most modern GANs don't use the WGAN loss anymore; most use BCE.
    • I'm still in favor of making Module.apply() return init(). It's just too much boilerplate to use an if-else every time. I avoided it by manually calling wgan.states = wgan.init(...) after model instantiation, which I think is also not nice.
    • Can we make Module.apply() accept params and states separately instead of collections? It's annoying having to construct a dict {'params': params, 'states': states} every time.
    • It would be nice if elegy.States were a dict so that the user can decide for themselves what to put into it. With GANs, where you have to manage generator and discriminator states separately, one always has to split them like (g_states, d_states) = net_states, which is again annoying.
    • Model.save() fails on this model, partially due to the extra jitted functions; but even when I remove them, cloudpickle chokes on _HooksContext.

    @cgarciae I'm not completely sure I've used the low-level API correctly, maybe you can take a closer look?

    opened by alexander-g 11
  • Add learning rate logging

    Implements the same functionality from #131 using only minor modifications to elegy.Optimizer.

    • [x] Add lr_schedule and steps_per_epoch to Optimizer.
    • [x] Implement Optimizer.get_effective_learning_rate
    • [x] Copy logging code from #131
    • [x] Add documentation

    @alexander-g Here is a proposal that is a bit simpler, closer to what I mentioned in #124. What do you think? @charlielito should we log the learning rate automatically if available or should we create a Callback?

    opened by cgarciae 9
  • Question: how to set the random state when calling model.predict(...)

    Not sure if this is the right place to post this...

    I have built and trained a VAE. When calling model.predict(x=test_set), I would like to make multiple predictions for each item in the test set (because VAEs are probabilistic). That way I can look at the distribution of predictions for each item in the test set.

    The call() for the VAE includes the line
    intrinsic_latents = mean + stds * jax.random.normal(self.next_key(), mean.shape).

    I haven't been able to find an explanation of how self.next_key() works or how to change the random seed on each call so that I can get different predictions. I could rewrite the code so that random seeds are explicitly passed, but I assume there is some functionality built into elegy to make this easy?

    Could someone explain how this works, or point me to the documentation explaining it?

    Thanks!

    opened by jfcrenshaw 8
  • Examples Cleanup

    • refactored examples/imagenet/resnet_imagenet.py to accept parameters instead of modifying them inside the script
    • added README.md for examples/imagenet/
    • removed unnecessary Lambda class from examples/mnist.py
    • moved global average pooling in examples/mnist_conv.py before the Linear layer
    opened by alexander-g 7
  • Resnet

    • ResNet model architecture and an example for training on ImageNet
      • code is mostly adapted from the flax library
      • pretrained ResNet50 with 76.5% accuracy
      • pretrained ResNet18 with 68.7% accuracy
    • Experimental support for mixed precision: previously all layers set their parameters' dtype to the input's dtype. This is incorrect; for numerical-stability reasons all parameters should be float32 even when performing float16 computations. See more here.
    • Some issues I had during training:
      • There seems to be a memory leak during training; RAM usage constantly increased.
      • I had to use smaller batch sizes than when training with flax or with TensorFlow before maxing out GPU memory (64 instead of 128 for ResNet50 on an RTX2080Ti). This might of course be due to a mistake in my code, but the number of parameters is identical to the flax and PyTorch versions, so I think the cause is somewhere else.
    opened by alexander-g 7
  • [Bug] Problem with computing metrics

    Describe the bug Hi, when I use the fit function I get an error saying that the update function is not provided with y_true and y_pred. It seems to come from the model's metrics, because if I comment out the metrics line there is no error:

    TypeError: update() missing 2 required positional arguments: 'y_true' and 'y_pred'
    

    Minimal code to reproduce

    import jax
    import jax.numpy as jnp
    import ml_collections
    import numpy as np
    import optax
    import elegy as eg
    
    
    class eCNN(eg.Module):
        """A simple CNN model."""
    
        @eg.compact
        def __call__(self, x):
            x = eg.Conv(10, kernel_size=(10,))(x)
            x = jax.nn.relu(x)
            x = eg.Linear(1)(x)
            x = jax.nn.sigmoid(x)
            return x
    
    n=200
    X_train = np.random.rand(n*100).reshape(n,100)
    y_train = np.random.rand(n).reshape(n,1)
    print(X_train.shape)
    print(y_train.shape)
    
    model = eg.Model(
        module=eCNN(),
        loss=[
            eg.losses.MeanSquaredError(),
        ],
        metrics=eg.metrics.MeanSquareError(),  #Line to be commented to get rid of the error
        optimizer=optax.rmsprop(1e-3),
    )
    
    model.fit(X_train,y_train,
        epochs=10,
        batch_size=20,
        #validation_data=0.1,
        shuffle=False,
        callbacks=[eg.callbacks.TensorBoard("summaries")]
        )
    

    Library Info

    import elegy
    print(elegy.__version__) 
    # 0.8.4
    
    bug 
    opened by organic-chemistry 6
  • Multi-gpu with pmap docs

    One of the selling points of jax is the pmap transformation, but best practices around actually making your training loop parallelizable are still confusing. What is elegy's story around multi-GPU training? Is it possible to get a pytorch-lightning-like API with a single arg to model.fit?

    opened by sooheon 6
  • SCCE fix for bug in Jax<0.2.7

    Small fix for a bug in Jax<0.2.7 where jax.lax.take_along_axis gives incorrect results for uint8 indices. Very relevant for semantic segmentation.

    Alternatively consider updating Jax

    opened by alexander-g 6
  • Dataset & DataLoader

    Dataset and parallel DataLoader API similar to PyTorch. Can be used with Model.fit()

    class MyDataset(elegy.data.Dataset):
        def __len__(self):
            return 128
    
        def __getitem__(self, i):
            #dummy data
            return np.random.random([224, 224, 3]),  np.random.randint(10)
    
    ds     = MyDataset()
    loader = elegy.data.DataLoader(ds, batch_size=8, n_workers=8, worker_type='thread', shuffle=True)
    
    batch = next(iter(loader))
    assert batch[0].shape == (8,224,224,3)
    assert batch[1].shape == (8,)
    assert len(loader) == 16
    
    model.fit(loader, epochs=10)
    
    opened by alexander-g 6
  • Implemented BinaryCrossentropy metric

    Updates:

    • Created BinaryCrossentropy metric
    • Created basic tests for BinaryCrossentropy metric (passing)
    • Created docs for BinaryCrossentropy metric
    • Refactored main docs by balancing files and correcting language typos
    documentation 
    opened by sebasarango1180 6
  • use poetry-core

    poetry-core is intended to be a light weight, fully compliant, self-contained package allowing PEP 517 compatible build frontends to build Poetry managed projects.

    Using poetry-core allows distribution packages to depend only on the build backend.

    opened by dotlambda 0
  • [Bug] Documentation/API reference not accessible via project website

    Hi, it looks like I can't actually access the API reference for Elegy. The corresponding link on the project's website simply takes me back to the main page (https://poets-ai.github.io/elegy/). Any idea what's up?

    bug 
    opened by geomlyd 0
  • [Bug] elegy does not work with latest haiku version

    Describe the bug When I type 'import elegy' I get this error

     File "/home/kpmurphy/mambaforge/lib/python3.10/site-packages/elegy/generalized_module/haiku_module.py", line 4, in <module>
        from haiku._src.base import current_bundle_name
    

    Minimal code to reproduce

    import elegy
    

    Expected behavior import elegy should succeed without errors.

    Library Info

    >> 
    >>> jax.__version__
    '0.2.28'
    >>> haiku.__version__
    '0.0.9.dev'
    >>> elegy.__version__. #  elegy-0.5.0-py3-none-any.whl 
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    NameError: name 'elegy' is not defined
    >>> 
    


    bug 
    opened by murphyk 5
  • CSVLogger iteration over a 0-d array

    Describe the bug When using the CSVLogger callback, elegy crashes at the end of the first epoch.

    Minimal code to reproduce

    import elegy as eg
    import optax
    import numpy as np
    
    x = np.random.randn(64, 1)
    y = np.random.randn(64, 1)
    
    model = eg.Model(
        eg.nn.Linear(1),
        loss=eg.losses.MeanSquaredError(),
        optimizer=optax.adam(1e-3),
    )
    
    hist = model.fit(
        x,
        y,
        epochs=10,
        callbacks=[
            eg.callbacks.CSVLogger("train.csv"),  # <-- commenting this line out avoids the crash
        ]
    )
    

    Stack trace:

    Epoch 1/10
    2/2 [==============================] - ETA: 0s - loss: 1.3408 - mean_squared_error_loss: 1.3408
    Traceback (most recent call last):
      File "/home/scott/.pyenv/versions/3.8.13/lib/python3.8/runpy.py", line 194, in _run_module_as_main
        return _run_code(code, main_globals, None,
      File "/home/scott/.pyenv/versions/3.8.13/lib/python3.8/runpy.py", line 87, in _run_code
        exec(code, run_globals)
      File "/home/scott/Documents/phd/geom/pde/csv.py", line 14, in <module>
        hist = model.fit(
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model_base.py", line 465, in fit
        callbacks.on_epoch_end(epoch, epoch_logs)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/callbacks/callback_list.py", line 221, in on_epoch_end
        callback.on_epoch_end(epoch, logs)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/callbacks/csv_logger.py", line 93, in on_epoch_end
        row_dict.update((key, handle_value(logs[key])) for key in self.keys)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/callbacks/csv_logger.py", line 93, in <genexpr>
        row_dict.update((key, handle_value(logs[key])) for key in self.keys)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/callbacks/csv_logger.py", line 68, in handle_value
        return '"[%s]"' % (", ".join(map(str, k)))
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/jax/_src/device_array.py", line 245, in __iter__
        raise TypeError("iteration over a 0-d array")  # same as numpy error
    TypeError: iteration over a 0-d array
    

    Expected behavior Should treat 0-d array as scalar.

    Library Info python version: 3.8.13, elegy version: 0.8.6, treex version: 0.6.10

    Additional context More detailed error information shows that the error occurs because the array is a jax DeviceArray, while the zero-dimensional check only tests for np.ndarray:

    is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
    
    │ /home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/callbacks/csv_logger.py:6 │
    │ 8 in handle_value                                                                                │
    │                                                                                                  │
    │    65 │   │   │   if isinstance(k, six.string_types):                                            │
    │    66 │   │   │   │   return k                                                                   │
    │    67 │   │   │   elif isinstance(k, tp.Iterable) and not is_zero_dim_ndarray:                   │
    │ ❱  68 │   │   │   │   return '"[%s]"' % (", ".join(map(str, k)))                                 │
    │    69 │   │   │   else:                                                                          │
    │    70 │   │   │   │   return k                                                                   │
    │    71                                                                                            │
    │                                                                                                  │
    │ ╭──────────────────────────── locals ─────────────────────────────╮                              │
    │ │ is_zero_dim_ndarray = False                                     │                              │
    │ │                   k = DeviceArray(4.8264385e-05, dtype=float32) │                              │
    │ ╰─────────────────────────────────────────────────────────────────╯                              │
    │                                                                                                  │
    │ /home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/jax/_src/device_array.py:245 in │
    │ __iter__                                                                                         │
    │                                                                                                  │
    │   242                                                                                            │
    │   243   def __iter__(self):                                                                      │
    │   244 │   if self.ndim == 0:                                                                     │
    │ ❱ 245 │     raise TypeError("iteration over a 0-d array")  # same as numpy error                 │
    │   246 │   else:                                                                                  │
    │   247 │     return (sl for chunk in self._chunk_iter(100) for sl in chunk._unstack())            │
    │   248                                                                                            │
    │                                                                                                  │
    │ ╭───────────────────── locals ─────────────────────╮                                             │
    │ │ self = DeviceArray(4.8264385e-05, dtype=float32) │                                             │
    │ ╰──────────────────────────────────────────────────╯                                             │
    ╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
    TypeError: iteration over a 0-d array
    
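    A duck-typed check would cover jax's DeviceArray as well as np.ndarray. A sketch of a possible fix (not the actual patch):

    # any array-like object with ndim == 0 should be treated as a scalar
    is_zero_dim_ndarray = hasattr(k, "ndim") and k.ndim == 0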
    bug 
    opened by ScottAlexanderCameron 0
  • Metrics ignore "on" keyword arg

    Metrics ignore "on" keyword arg

    Describe the bug I have an application where I need to output multiple values from a network, which I am doing using a dictionary and using the on keyword argument. This works fine for the loss functions but not for metrics.

    Minimal code to reproduce

    import elegy as eg
    import optax
    import numpy as np
    
    
    def data_generator():
        while True:
            yield (
                np.random.randn(10, 1),
                {"target": {"y": np.random.randn(10, 1)}},
            )
    
    
    class MyModule(eg.Module):
        @eg.compact
        def __call__(self, x):
            return {"y": eg.nn.Linear(1)(x)}
    
    
    model = eg.Model(
        MyModule(),
        loss=eg.losses.MeanSquaredError(on="y"),
        metrics=eg.metrics.MeanAbsoluteError(on="y"),  #  <-- works fine without this line
        optimizer=optax.adam(1e-3),
    )
    
    hist = model.fit(
        data_generator(),
        steps_per_epoch=10,
        epochs=10,
    )
    

    Stack trace:

    Traceback (most recent call last):
      File "/home/scott/.pyenv/versions/3.8.13/lib/python3.8/runpy.py", line 194, in _run_module_as_main
        return _run_code(code, main_globals, None,
      File "/home/scott/.pyenv/versions/3.8.13/lib/python3.8/runpy.py", line 87, in _run_code
        exec(code, run_globals)
      File "/home/scott/Documents/phd/geom/pde/metric.py", line 27, in <module>
        hist = model.fit(
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model_base.py", line 417, in fit
        tmp_logs = self.train_on_batch(
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model_core.py", line 617, in train_on_batch
        logs, model = train_step_fn(self, inputs, labels)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model_core.py", line 412, in _static_train_step
        return model.train_step(inputs, labels)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model.py", line 306, in train_step
        grads, (logs, model) = grad_fn(params, model, inputs, labels)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model.py", line 278, in loss_fn
        loss, logs, model = model.test_step(inputs, labels)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model.py", line 248, in test_step
        batch_loss_and_logs.update(
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/treex/metrics/loss_and_logs.py", line 78, in update
        self.metrics.update(**metrics_kwargs)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/treex/metrics/metrics.py", line 44, in update
        metric.update(**metric_kwargs)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/treex/metrics/mean_absolute_error.py", line 83, in update
        values = _mean_absolute_error(preds, target)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/treex/metrics/mean_absolute_error.py", line 20, in _mean_absolute_error
        target = target.astype(preds.dtype)
    AttributeError: 'dict' object has no attribute 'astype'
    

    Expected behavior Should produce the same result as if the dictionaries are removed and the on arg not specified.

    Library Info python version: 3.8.13, elegy version: 0.8.6, treex version: 0.6.10

    Additional context From my digging the cause seems to be due to the Metric.update() method being called instead of the __call__ method.

    bug 
    opened by ScottAlexanderCameron 0
  • [Bug] Elegy crash on GPU

    Describe the bug

    Running mnist_cnn.py in the examples dir crashes the instance at the end of the first epoch.

    This was previously reported on a Colab GPU instance, but I can reproduce it on the CLI too.

    Running on CPU does not have this problem.

    Running on eager mode with GPU does not have this problem.

    Minimal code to reproduce

    python mnist_cnn.py
    

    Expected behavior Not stuck.

    Library Info CentOS Linux release 7.6.1810 elegy 0.8.6

    Additional context absl-py==1.2.0 aiohttp==3.8.1 aiosignal==1.2.0 async-timeout==4.0.2 attrs==22.1.0 certifi==2021.10.8 charset-normalizer==2.1.1 chex==0.1.4 click==8.1.3 cloudpickle==1.6.0 colorama==0.4.5 commonmark==0.9.1 cycler==0.11.0 datasets==2.4.0 dill==0.3.5.1 dm-tree==0.1.7 docker-pycreds==0.4.0 einops==0.4.1 elegy==0.8.6 etils==0.7.1 filelock==3.8.0 flax==0.4.2 fonttools==4.36.0 frozenlist==1.3.1 fsspec==2022.7.1 gitdb==4.0.9 GitPython==3.1.27 h5py==3.6.0 huggingface-hub==0.8.1 idna==3.3 importlib-resources==5.9.0 jax==0.3.16 jaxlib==0.3.15+cuda11.cudnn82 kiwisolver==1.4.4 matplotlib==3.5.3 msgpack==1.0.4 multidict==6.0.2 multiprocess==0.70.13 numpy==1.22.3 opt-einsum==3.3.0 optax==0.1.3 packaging==21.3 pandas==1.4.3 pathtools==0.1.2 Pillow==9.2.0 promise==2.3 protobuf==3.20.1 psutil==5.9.1 pyarrow==9.0.0 Pygments==2.13.0 pyparsing==3.0.9 python-dateutil==2.8.2 pytz==2022.2.1 PyYAML==6.0 requests==2.28.1 responses==0.18.0 rich==11.2.0 scipy==1.8.0 sentry-sdk==1.9.5 setproctitle==1.3.2 shortuuid==1.0.9 six==1.16.0 smmap==5.0.0 tensorboardX==2.5.1 toolz==0.12.0 tqdm==4.64.0 treeo==0.0.10 treex==0.6.10 typing_extensions==4.3.0 urllib3==1.26.11 wandb==0.12.21 xxhash==3.0.0 yarl==1.8.1 zipp==3.8.1

    bug 
    opened by jiyuuchc 2
Releases
  • 0.8.6(Mar 23, 2022)

    🚀 Features

    • Weights and Biases Callback for Elegy
      • PR: #220

    🐛 Fixes

    • Docs typos
      • PR: #222
    • Donate model's memory buffer to jit/pmap functions.
      • PR: #226
    • Lazy load wandb
      • PR: #228
    Source code(tar.gz)
    Source code(zip)
  • 0.8.5(Feb 23, 2022)

  • 0.8.4(Dec 14, 2021)

  • 0.8.3(Dec 13, 2021)

  • 0.8.2(Dec 13, 2021)

  • 0.8.1(Nov 8, 2021)

    Elegy is now based on Treex 🎉

    Changes

    • Removes the module, nn, metrics, and losses submodules from Elegy; instead, Elegy re-exports these modules from Treex.
    • GeneralizedModule and friends are gone, to use Flax Modules use the elegy.nn.FlaxModule wrapper.
    • Low-level API is massively simplified:
      • States is removed; since Model is a pytree, all parameters are tracked automatically thanks to Treex / Treeo.
      • All static state arguments (training, initializing) are removed; Modules can simply use self.training to check their training state and self.initializing() to check whether they are initializing.
      • The signature for pred_step, test_step, and train_step now simply consists of inputs and labels, where labels is a dict that can contain additional keys like sample_weight or class_weight as required by the losses and metrics.
    • Adds the DistributedStrategy class, which currently has 3 instances:
      • Eager: runs the model on a single device in eager mode (no jit)
      • JIT: runs the model on a single device with jit
      • DataParallel: runs the model on multiple devices using pmap
    • Adds methods to change the model's distributed strategy:
      • .distributed(strategy = DataParallel): changes the distributed strategy, DataParallel used by default.
      • .local(): changes the distributed strategy to JIT.
      • .eager(): changes the distributed strategy to Eager.
    • Removes the .eager field in favor of the .eager() method.
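    A usage sketch based on the notes above (method names as described in this changelog, not verified against the released API):

    model = model.distributed()  # DataParallel (pmap) across local devices, the default strategy
    model = model.local()        # back to single-device JIT
    model = model.eager()        # run without jit, e.g. for debugging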
    Source code(tar.gz)
    Source code(zip)
  • 0.7.4(Jun 1, 2021)

  • 0.7.2(Mar 10, 2021)

  • 0.7.1(Mar 1, 2021)

  • 0.7.0(Feb 22, 2021)

    Features

    • init is now only called once internally and is only required to be called explicitly by the user under certain circumstances
    • summary now uses jax.eval_shape under the hood, so it's very fast since it doesn't allocate memory or perform any computations on the device.
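    For context, jax.eval_shape traces a function using only abstract shapes and dtypes, so nothing is allocated or computed on the device. A minimal, Elegy-independent illustration:

    import jax
    import jax.numpy as jnp

    def forward(x):
        return jnp.dot(x, jnp.ones((784, 10)))

    # only shapes/dtypes are propagated; no device memory is allocated
    out = jax.eval_shape(forward, jax.ShapeDtypeStruct((64, 784), jnp.float32))
    print(out.shape, out.dtype)  # (64, 10) float32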

    Merged pull requests:

    • Fix notebook #166 (cgarciae)
    • Single Initialization: Removes the current progressive initialization in favor of a single underlying call to init_step. #165 (cgarciae)
    Source code(tar.gz)
    Source code(zip)
  • 0.6.0(Feb 14, 2021)

  • 0.5.0(Feb 8, 2021)

    This version simplifies parts of the low-level API in the spirit of what was introduced in 0.4.0 to provide a more homogeneous and simpler experience.

    Merged pull requests:

    • Improve States: uses __dict__ so States works with vars #159 (cgarciae)
    • Simplify API: Cleans-up some API details around Model and Module #158 (cgarciae)
    Source code(tar.gz)
    Source code(zip)
  • 0.4.1(Feb 3, 2021)

  • 0.4.0(Feb 1, 2021)

    Implemented enhancements:

    • [Feature Request] Monitoring learning rates #124

    Merged pull requests:

    Source code(tar.gz)
    Source code(zip)
  • 0.3.0(Dec 17, 2020)

    Implemented enhancements:

    • elegy.nn.Sequential docs not clear #107
    • [Feature Request] Community example repo. #98

    Fixed bugs:

    • [Bug] Accuracy from Model.evaluate() is inconsistent with manually computed accuracy #109
    • Exceptions in "Getting Started" colab notebook #104

    Closed issues:

    • l2_normalize #102
    • Need some help for contributing new losses. #93
    • Document Sum #62
    • Binary Accuracy Metric #58
    • Automate generation of API Reference folder structure #19
    • Implement Model.summary #3

    Merged pull requests:

    Source code(tar.gz)
    Source code(zip)
  • 0.2.2(Aug 31, 2020)

  • 0.2.1(Aug 25, 2020)

  • 0.2.0(Aug 17, 2020)

  • 0.1.5(Jul 28, 2020)

    • Mean Absolute Percentage Error Implementation @Ciroye
    • Adds elegy.nn.Linear, elegy.nn.Conv2D, elegy.nn.Flatten, elegy.nn.Sequential @cgarciae
    • Add Elegy hooks @cgarciae
    • Improves Tensorboard support @Davidnet
    • Added coverage metrics to CI @charlielito
    Source code(tar.gz)
    Source code(zip)
  • 0.1.4(Jul 24, 2020)

    • Adds elegy.metrics.BinaryCrossentropy @sebasarango1180
    • Adds elegy.nn.Dropout and elegy.nn.BatchNormalization @cgarciae
    • Improves documentation
    • Fixes a bug that caused an error when using is_training via dependency injection on Model.predict.
    Source code(tar.gz)
    Source code(zip)
  • 0.1.3(Jul 23, 2020)
