Hardware accelerated, batchable and differentiable optimizers in JAX.

JAXopt

Hardware accelerated (GPU/TPU), batchable and differentiable optimizers in JAX.

Installation

JAXopt can be installed with pip directly from GitHub, with the following command:

$ pip install git+https://github.com/google/jaxopt

Alternatively, it can be installed from source with the following command:

$ python setup.py install

References

Our implicit differentiation framework is described in this paper. To cite it:

@article{jaxopt_implicit_diff,
  title={Efficient and Modular Implicit Differentiation},
  author={Blondel, Mathieu and Berthet, Quentin and Cuturi, Marco and Frostig, Roy and Hoyer, Stephan and Llinares-L{\'o}pez, Felipe and Pedregosa, Fabian and Vert, Jean-Philippe},
  journal={arXiv preprint arXiv:2105.15183},
  year={2021}
}

Disclaimer

JAXopt is an open source project maintained by a dedicated team in Google Research, but is not an official Google product.

Comments
  • Levenberg-Marquardt running exceptionally slowly unless verbose is enabled

    The Levenberg-Marquardt optimizer runs exceptionally slowly (~30 seconds for 30 iterations) unless I turn on verbose=True (~1 second for 30 iterations). Any idea what may be going on? Enabling JIT seems to have no impact. I was hoping to use this for a real-time system, but even at 1 second things are way too slow.
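When comparing wall-clock timings like these, keep in mind that JAX dispatches work asynchronously and that the first call to a jitted function also pays for tracing and compilation. A minimal sketch of timing a jitted function separately from its compilation, where `step` is a hypothetical stand-in for one solver iteration (not jaxopt code):

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def step(x):
    # Hypothetical stand-in for one solver iteration.
    return jnp.tanh(x @ x.T) @ x


def timed(fn, *args):
    """Return (elapsed seconds, output) for one call of fn(*args)."""
    t0 = time.perf_counter()
    out = fn(*args)
    # Wait for the asynchronous computation to finish before stopping the clock.
    out.block_until_ready()
    return time.perf_counter() - t0, out


x = jnp.full((128, 128), 0.01)
first_call, _ = timed(step, x)     # includes tracing + XLA compilation
second_call, out = timed(step, x)  # reuses the cached compiled function
print(f"first call: {first_call:.4f}s, second call: {second_call:.4f}s")
```

Timing only the second and later calls isolates the per-iteration cost from one-off compilation cost.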

    opened by pablovela5620 23
  • implementation of Fletcher-Reeves Algorithm

    • Polak-Ribiere method: to my knowledge, conjugate gradient variants have been quite successful in general unconstrained optimization.

    This PR depends on Line Search of PR #128.

    • Beta division is required to guarantee the strong Wolfe condition, but (I don't know why) it raises an error.
    pull ready 
    opened by ita9naiwa 17
  • vmap support in QPs

    Hi, I am experiencing some problems with projection_polyhedron:

    import numpy as np
    import matplotlib.pyplot as plt
    
    import jax
    import jax.numpy as jnp
    
    import jaxopt
    from jaxopt.projection import projection_l2_ball, projection_box, projection_l1_ball, projection_polyhedron
    
    def myproj3(x):
        A = jnp.array([[1.0, 1.0]])
        b = jnp.array([1.0])
        G = jnp.array([[-1.0, 0.0], [0.0, -1.0]])
        h = jnp.array([0.0, 0.0])    
        x = projection_polyhedron(x,hyperparams = (A, b, G, h))
        return x
    
    rng_key = jax.random.PRNGKey(42)
    x = jax.random.uniform(rng_key, (5000,2), minval=-3,maxval=3)
    p1_x = jax.vmap(myproj3)(x)  # myproj3 takes a single argument, batched along axis 0
    fig, ax = plt.subplots(figsize=(5,5))
    ax.scatter(x[:,0],x[:,1],s=0.5)
    ax.scatter(p1_x[:,0],p1_x[:,1],s=0.5,c='g')
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    plt.show()
    

    First, I had to install cvxpy (pip install cvxpy). Then I got this error:

    TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float64[2])>with<BatchTrace(level=1/1)>
      with val = DeviceArray([[-2.37103211,  2.33759997],
                              [ 2.76953806, -2.37750394],
                              [-0.87246632,  0.73224625],
                              ...,
                              [ 2.29799773,  2.81894884],
                              [ 2.4022714 ,  0.80693103],
                              [-0.41563116,  2.83898531]], dtype=float64)
           batch_dim = 0
    

    Does anyone have a hint? Thanks
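The constraints in `myproj3` (x1 + x2 = 1, x >= 0) describe the 2-D probability simplex. As a workaround sketch, assuming only this particular polyhedron is needed, a projection written purely in `jax.numpy` can be batched with `jax.vmap`. This uses the standard sort-based simplex projection rather than jaxopt's `projection_polyhedron`; `project_simplex` is a hypothetical name (jaxopt.projection also has a `projection_simplex` function that may serve the same purpose):

```python
import jax
import jax.numpy as jnp


def project_simplex(x):
    """Euclidean projection of a 1-D vector onto {y : sum(y) = 1, y >= 0}."""
    n = x.shape[0]
    u = jnp.sort(x)[::-1]                   # entries in decreasing order
    cssv = jnp.cumsum(u) - 1.0              # cumulative sums minus the target sum
    ind = jnp.arange(1, n + 1)
    cond = u - cssv / ind > 0
    rho = jnp.max(jnp.where(cond, ind, 0))  # size of the support (always >= 1)
    theta = cssv[rho - 1] / rho             # shift that makes the result sum to 1
    return jnp.maximum(x - theta, 0.0)


x = jax.random.uniform(jax.random.PRNGKey(42), (5000, 2), minval=-3, maxval=3)
p_x = jax.vmap(project_simplex)(x)  # batched over the leading axis
```

Since every operation is traceable, the same function also works under `jax.jit` and `jax.grad`.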

    enhancement 
    opened by jecampagne 12
  • KKT conditions when the primal solution is a pytree

    Hi, congrats on the great tool! Inspired by the QuadraticProgramming example, I built code that differentiates through KKT conditions. My code works whenever the primal solution variable is a jnp array, but not when it's a generic pytree. It gives me the following error:

    TypeError: Tree structure of cotangent input PyTreeDef(([(*, *), (), (*, *)], *, None)), does not match structure of primal output PyTreeDef(([(*, *), (), (*, *), (*, *), (), (*, *)], *, None))

    where I'm pretty sure [(*, *), (), (*, *)] represents the primal solution and PyTreeDef(([(*, *), (), (*, *)], *, None)) could represent the optimality function.

    I was able to make it work by storing the primal solution in a single jnp array and reshaping it into the appropriate pytree whenever needed, but it's not clean or efficient. I was wondering if there's a bug in the current codebase (I only found tests for single jnp arrays) or I'm misusing the interface (I'm not a jax expert).

    To make it easier to reproduce, I modified the quadratic_prog.py file by making the model return a list of one array instead of an array for the primal variables (leaving both dual variables the same). Then I modified obj_fun, eq_fun and ineq_fun to use primal_var[0] instead of primal_var. If I understand correctly, this should still work. However, it doesn't: this test line raises an assertion error for an array that should be all zeros but instead is: ([DeviceArray([ 0.43999994, -1.3199999 ], dtype=float32), DeviceArray([-0.44000003, 1.32 ], dtype=float32)], DeviceArray([2.9802322e-08], dtype=float32), None)

    Looking at the numbers of the problem I believe [0.44,-1.32] is the gradient of the obj_fun w.r.t. the primal and [-0.44,+1.32] the gradient of the equality constraint w.r.t. the primal times the dual. They should have been added up together to have [0,0] as expected. I feel this may be fundamentally the same problem I was facing in my own research code since there I also found one of the values had the shape of the primal variable twice instead of once.

    Notice also that the test on the line just above (checking that the primal solution is correct) still holds, provided we check sol[0][0] instead of sol[0] (since sol[0] is now a 1-element list).

    Is differentiation through KKT supposed to work for general pytrees? If so, what should I have done to make it work in the quadratic_prog.py example?

    Thanks!
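The workaround mentioned above (storing the primal solution in a single array and reshaping it into the pytree when needed) can be written with `jax.flatten_util.ravel_pytree`, which returns the flat vector together with a function that rebuilds the pytree. A minimal sketch with a hypothetical two-leaf primal variable and a hypothetical objective:

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical pytree-valued primal variable: a list of two arrays.
primal = [jnp.array([1.0, 2.0]), jnp.array([3.0])]
flat, unravel = ravel_pytree(primal)  # flat has shape (3,)


def obj_fun(primal_var):
    # Objective written against the pytree structure.
    return sum(jnp.sum(p ** 2) for p in primal_var)


def obj_fun_flat(x):
    # Same objective, exposed to code that expects a single flat array.
    return obj_fun(unravel(x))


g_flat = jax.grad(obj_fun_flat)(flat)  # gradient as a flat vector
g_tree = unravel(g_flat)               # same gradient, back in pytree form
```

This sidesteps pytree structure mismatches rather than fixing them, at the cost of one reshape per call.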

    opened by FerranAlet 11
  • Hot fix: corrected condition in lbfgs

    The feature I had introduced in https://github.com/google/jaxopt/pull/323 was failing when the run function was jitted and was a no-op when not because of the following reason:

     ~True == -2  # this is True
    

    Therefore when jitted it was complaining about different types in a condition function, and when not jitted it was equivalent to always being False.

    EDIT

    Actually I am still running into an error when jitted, so will continue to investigate.

    The gist of the error is "Abstract tracer value encountered where concrete value is expected": basically, doing (not self.stop_if_linesearch_fails | ~state.failed_linesearch) is not allowed because one is a bool and the other is an abstract value.
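Two Python details are behind this: `bool` is a subclass of `int`, so `~True` is bitwise NOT (`-2`, which is truthy), and `not` binds more loosely than `|`, so `not a | ~b` parses as `not (a | ~b)` and applies `not` to a traced array. A minimal sketch of a jit-friendly version of the condition, using a hypothetical helper `keep_going` (not jaxopt code) built only from `jnp.logical_*` operations:

```python
import jax
import jax.numpy as jnp

# bool is an int subclass, so ~ is bitwise NOT rather than logical negation.
assert ~True == -2 and ~False == -1

# `not` has lower precedence than `|`: this parses as `not (False | True)`.
assert (not False | True) is False


def keep_going(stop_if_linesearch_fails, failed_linesearch):
    # Hypothetical helper: continue unless we were asked to stop on failure
    # and the line search did fail. Only array ops, so it traces under jit.
    return jnp.logical_or(jnp.logical_not(stop_if_linesearch_fails),
                          jnp.logical_not(failed_linesearch))


# The Python flag is marked static; the failure flag is traced.
keep_going_jit = jax.jit(keep_going, static_argnums=0)
```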

    pull ready 
    opened by zaccharieramzi 7
  • Issue with gradients wrt optimality fn parameters through root finding vjp

    First of all, thanks a lot for this library! Really useful tools! I'm interested in getting at least 2nd order gradients through root finding, and I'm finding an odd behavior that I wanted to report.

    Maybe I'm doing something wrong, but in the following schematic case I silently get the wrong gradients:

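    import jax
    from jaxopt import Bisection
    
    # F is the reporter's optimality function (not shown in the report).
    def inv_f(x, aux):
      bisec = Bisection(optimality_fun=F, lower=0.0, upper=1., 
                        check_bracket=False, unroll=True)
      return bisec.run(aux=aux).params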
    
    # Here I extract the value part of the vjp, but the grad part also gives wrong results
    test_fn = lambda aux: jax.value_and_grad(inv_f)(0.5, aux)[0] 
    
    jax.grad(test_fn)(1.) # Returns 0 instead of the expected gradients
    

    Here I'm only trying to get gradients of the value returned by jax.value_and_grad, but the gradients of the gradients returned by jax.value_and_grad are also wrong (but not as obvious).

    I made a small demo notebook that reproduces this issue here.

    As a reference I've also implemented my own implicit gradients, bypassing the jaxopt ones, and they seem to give me the correct answer.

    Reading the source code of jaxopt, it is not immediately obvious to me why this doesn't work. Sorry I couldn't directly suggest a PR, but I hope this report is still useful (and that I'm not just using jaxopt wrong).
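The reporter mentions implementing their own implicit gradients as a reference. A minimal sketch of what that can look like in plain JAX (not jaxopt's implementation), assuming a scalar root and a hypothetical optimality function F(x, a) = x**2 - a whose root on [0, 4] is sqrt(a): bisection computes the root, and a `jax.custom_jvp` rule applies the implicit function theorem, dx/da = -(dF/da) / (dF/dx). Because the rule is itself written with differentiable JAX operations, second derivatives work too.

```python
import jax
import jax.numpy as jnp


def F(x, a):
    # Hypothetical optimality function: its root on [0, 4] is sqrt(a).
    return x ** 2 - a


def _bisect(a, lower=0.0, upper=4.0, n_iter=60):
    a = jnp.asarray(a, dtype=jnp.float32)

    def body(_, bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        # F is increasing in x on [0, 4], so F(mid) > 0 means the root is left of mid.
        go_left = F(mid, a) > 0
        return jnp.where(go_left, lo, mid), jnp.where(go_left, mid, hi)

    init = (jnp.array(lower, dtype=jnp.float32), jnp.array(upper, dtype=jnp.float32))
    lo, hi = jax.lax.fori_loop(0, n_iter, body, init)
    return 0.5 * (lo + hi)


@jax.custom_jvp
def root(a):
    return _bisect(a)


@root.defjvp
def root_jvp(primals, tangents):
    (a,), (a_dot,) = primals, tangents
    x = root(a)
    # Implicit function theorem at F(x, a) = 0: dx/da = -(dF/da) / (dF/dx).
    dF_dx = jax.grad(F, argnums=0)(x, a)
    dF_da = jax.grad(F, argnums=1)(x, a)
    return x, -dF_da / dF_dx * a_dot
```

For this F, the expected values at a = 4 are x = 2, dx/da = 1/(2x) = 0.25 and d2x/da2 = -1/(4x^3) = -0.03125, which gives a simple check for any implicit-differentiation code.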

    bug 
    opened by EiffL 7
  • misc improvements to robust training example

    main changes:

    • Fixes #134 by normalizing in-place.
    • Plot convergence curves for both clean and adversarial accuracy.
    • Replace the fast gradient sign method by the much more powerful PGD method.
    • Be able to select different datasets.
    • Homogenize the API with respect to the other examples. For example, this now uses the same load_dataset, CNN, loss_fun and accuracy as flax_image_classif.py. Most of the command-line flags have also been homogenized.
    pull ready 
    opened by fabianp 7
  • Bisection hanging

    I am trying to use jaxopt.Bisection to replace scipy.optimize.bisect in a computational model, but Bisection hangs when I run my code.

    The basic structure includes 2 functions that are both jitted (so I assume it should be able to compile ok):

    @jit
    def f1(parameters):
        ....
        return jax.numpy.array([a,b,c])
    
    @jit
    def opt_fun(x):
        f1(x,params)
        .... 
        return float_value
    

    When I call scipy.optimize.bisect(opt_fun,x0,x1), it runs with no issue, but jaxopt.Bisection(opt_fun,x0,x1).run(None) hangs with ~10% CPU usage and 55% memory usage on an i9 2018 MacBook Pro with 32GB of memory.

    I acknowledge I may be using this incorrectly and that this is possibly not the intended use case but any direction would be very helpful. My intention is to use this computational model with numpyro in the future and having a jax version of the bisection root finding would be incredibly helpful.

    opened by jjruby09 7
  • Incompatible shape in solve_normal_cg

    When A.shape = (N, P) for N != P, I run into shape errors when trying to use solve_normal_cg for fitting the normal equations.

    I have a small reproducible example below for N > P, but the error also occurs when P > N.

    import jax.numpy as jnp
    import numpy as np
    N = 1000
    P = 3
    prob = np.random.uniform(0.01, 0.5, size=P)
    h2g = 0.1
    X = np.random.binomial(2, p=prob, size=(N, P))
    b = np.random.normal(size=(P)) * np.sqrt(h2g / P)
    y = X @ b + np.sqrt(1 - h2g) * np.random.normal(size=(N,))
    
    import jaxopt as jopt
    jopt.linear_solve.solve_normal_cg(lambda x: jnp.dot(X, x), y)
    WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
    ---------------------------------------------------------------------------
    TypeError                                 Traceback (most recent call last)
    Input In [11], in <module>
    ----> 1 jopt.linear_solve.solve_normal_cg(lambda x: jnp.dot(X, x), y)
    
    File ~/miniconda3/lib/python3.9/site-packages/jaxopt/_src/linear_solve.py:151, in solve_normal_cg(matvec, b, ridge, init, **kwargs)
        148 if ridge is not None:
        149   _matvec = _make_ridge_matvec(_matvec, ridge=ridge)
    --> 151 Ab = _rmatvec(matvec, b)
        153 return jax.scipy.sparse.linalg.cg(_matvec, Ab, x0=init, **kwargs)[0]
    
    File ~/miniconda3/lib/python3.9/site-packages/jaxopt/_src/linear_solve.py:114, in _rmatvec(matvec, x)
        112 def _rmatvec(matvec, x):
        113   """Computes A^T x, from matvec(x) = A x, where A is square."""
    --> 114   transpose = jax.linear_transpose(matvec, x)
        115   return transpose(x)[0]
    
    File ~/miniconda3/lib/python3.9/site-packages/jax/_src/api.py:2211, in linear_transpose(fun, reduce_axes, *primals)
       2208 in_dtypes = map(dtypes.dtype, in_avals)
       2210 in_pvals = map(pe.PartialVal.unknown, in_avals)
    -> 2211 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(flat_fun, in_pvals,
       2212                                              instantiate=True)
       2213 out_avals, _ = unzip2(out_pvals)
       2214 out_dtypes = map(dtypes.dtype, out_avals)
    
    File ~/miniconda3/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:505, in trace_to_jaxpr(fun, pvals, instantiate)
        503 with core.new_main(JaxprTrace) as main:
        504   fun = trace_to_subjaxpr(fun, main, instantiate)
    --> 505   jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
        506   assert not env
        507   del main, fun, env
    
    File ~/miniconda3/lib/python3.9/site-packages/jax/linear_util.py:166, in WrappedFun.call_wrapped(self, *args, **kwargs)
        163 gen = gen_static_args = out_store = None
        165 try:
    --> 166   ans = self.f(*args, **dict(self.params, **kwargs))
        167 except:
        168   # Some transformations yield from inside context managers, so we have to
        169   # interrupt them before reraising the exception. Otherwise they will only
        170   # get garbage-collected at some later time, running their cleanup tasks only
        171   # after this exception is handled, which can corrupt the global state.
        172   while stack:
    
    Input In [11], in <lambda>(x)
    ----> 1 jopt.linear_solve.solve_normal_cg(lambda x: jnp.dot(X, x), y)
    
    File ~/miniconda3/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py:4196, in dot(a, b, precision)
       4194   return lax.mul(a, b)
       4195 if _max(a_ndim, b_ndim) <= 2:
    -> 4196   return lax.dot(a, b, precision=precision)
       4198 if b_ndim == 1:
       4199   contract_dims = ((a_ndim - 1,), (0,))
    
    File ~/miniconda3/lib/python3.9/site-packages/jax/_src/lax/lax.py:667, in dot(lhs, rhs, precision, preferred_element_type)
        664   return dot_general(lhs, rhs, (((lhs.ndim - 1,), (0,)), ((), ())),
        665                      precision=precision, preferred_element_type=preferred_element_type)
        666 else:
    --> 667   raise TypeError("Incompatible shapes for dot: got {} and {}.".format(
        668       lhs.shape, rhs.shape))
    
    TypeError: Incompatible shapes for dot: got (1000, 3) and (1000,).
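The traceback shows `jax.linear_transpose` being traced with `b` (shape `(N,)`) as the example input of `matvec`, which expects an input of shape `(P,)`. A minimal sketch of a normal-equations CG solve that builds the transpose from an input of the right shape; `solve_normal_cg_sketch` and its `x_example` argument are hypothetical, not jaxopt's API:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.sparse.linalg import cg


def solve_normal_cg_sketch(matvec, b, x_example, **cg_kwargs):
    """Solve A^T A x = A^T b with CG, given only matvec(x) = A x.

    x_example is any array with the shape/dtype of matvec's input (P,),
    which jax.linear_transpose uses to trace matvec.
    """
    rmatvec = jax.linear_transpose(matvec, x_example)  # y -> (A^T y,)
    Atb = rmatvec(b)[0]

    def normal_matvec(x):
        return rmatvec(matvec(x))[0]  # x -> A^T A x

    return cg(normal_matvec, Atb, **cg_kwargs)[0]


rng = np.random.default_rng(0)
N, P = 1000, 3
X = jnp.asarray(rng.normal(size=(N, P)), dtype=jnp.float32)
b_true = jnp.array([0.5, -1.0, 2.0])
y = X @ b_true
b_hat = solve_normal_cg_sketch(lambda x: jnp.dot(X, x), y, jnp.zeros(P))
```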
    
    opened by quattro 6
  • Initial stepsize not exposed in LBFGS constructor [question/bug?]

    I see that LbfgsState contains a stepsize and that LBFGS.init_state hard-codes it to 1. I also see that the LBFGS.update method performs a line search in which the initial step size is set from this LBFGS state.

    I have a particularly ill-conditioned problem that requires tiny initial steps, but I was surprised that the initial stepsize could not be set in the LBFGS constructor or elsewhere as far as I could see. Is this an oversight or an intentional part of the design? If it's intentional, is there an idiomatic way to set an initial stepsize when using LBFGS.run that I have overlooked?

    Thanks in advance, and thanks for a really cool library.

    opened by erdmann 6
  • Infinities and NaNs in quadratic_prog when c=0

    Hi,

    I'm using QuadraticProgramming in the special case of c=0 (all zeros as a vector). AFAIK this is still well-defined, as it's just minimizing l2 norm squared of the primal subject to some equality constraints (I don't have inequalities).

    However, both my research code and the following modification of this test diverge even for a single step (maxiter=1).

    The modification just involves setting c=0, so:

    def test_qp_eq_only_c_zero(self):
      Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]])
      c = jnp.array([0.0, 0.0]) #ONLY CHANGE
      A = jnp.array([[1.0, 1.0]])
      b = jnp.array([1.0])
      qp = QuadraticProgramming(tol=1e-7)
      hyperparams = dict(params_obj=(Q, c), params_eq=(A, b))
      sol = qp.run(**hyperparams).params
      self.assertAllClose(qp.l2_optimality_error(sol, **hyperparams), 0.0)
      self._check_derivative_A_and_b(qp, hyperparams, A, b)
    

    Is there a way to fix it? If it involves calling another linear solver, is there a way to specify the solver from the high-level QP function? I haven't seen it.

    Thanks!
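For reference, the modified test problem is well defined. Assuming the objective 0.5 * x^T Q x + c^T x, its solution satisfies the KKT linear system below, which is nonsingular because Q is positive definite and A has full row rank. With c = 0 the right-hand side reduces to (0, 0, 1), and the system can be solved by hand (the sign of the multiplier nu depends on the convention used):

```latex
\begin{bmatrix} Q & A^\top \\ A & 0 \end{bmatrix}
\begin{bmatrix} x \\ \nu \end{bmatrix}
=
\begin{bmatrix} -c \\ b \end{bmatrix},
\qquad
Q = \begin{bmatrix} 4 & 1 \\ 1 & 2 \end{bmatrix},\;
A = \begin{bmatrix} 1 & 1 \end{bmatrix},\;
b = 1,\; c = 0
\;\Longrightarrow\;
\begin{bmatrix} 4 & 1 & 1 \\ 1 & 2 & 1 \\ 1 & 1 & 0 \end{bmatrix}
\begin{bmatrix} x_1 \\ x_2 \\ \nu \end{bmatrix}
=
\begin{bmatrix} 0 \\ 0 \\ 1 \end{bmatrix}

4x_1 + x_2 + \nu = 0,\quad x_1 + 2x_2 + \nu = 0
\;\Rightarrow\; 3x_1 - x_2 = 0,
\qquad x_1 + x_2 = 1
\;\Rightarrow\; x^\star = \left(\tfrac{1}{4}, \tfrac{3}{4}\right),\quad \nu^\star = -\tfrac{7}{4}
```

So with c = 0 the solution is finite and unique, and any Inf/NaN comes from the solver rather than from the problem itself.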

    opened by FerranAlet 6
  • OptaxSolver Error: too many positional arguments

    Hello! I tried to implement the example of implicit differentiation shown here, but with my own functions. The task is to find the mean of a set of vectors X via gradient descent.

    import numpy as np
    import matplotlib.pyplot as plt
    from tqdm import tqdm
    
    import jax
    import jax.numpy as jnp
    from jax import grad, random, jit
    from jax import jacobian, hessian, jacfwd, jacrev
    key = random.PRNGKey(0)
    
    import jaxopt
    from jaxopt import implicit_diff
    from jaxopt import linear_solve
    from jaxopt import OptaxSolver, GradientDescent
    import optax
    
    def euclidean_distance(a, b):
        """
        Squared Euclidean distance
        """
        return jnp.inner(a - b, a - b)
    
    def weighted_distance(x, X, w):
        loss = 0
        for i, obj in enumerate(X):
            loss += w[i] * euclidean_distance(obj, x)
        return loss
    
    def identical(Y, Y_grad):
        return Y
    

    Algorithm for finding mean:

    # Mean calculation for manifolds with gradient descent
    @implicit_diff.custom_root(jax.grad(weighted_distance))
    def euclidean_weighted_mean(X_set, weights = None, lr = 0.1, n_iter = 50, plot_loss_flag = False):
        
        if weights is None:
            weights = jnp.full((X_set.shape[0]), 1) / X_set.shape[0]
    
        # init mean with random element from set
        Y = X_set[np.random.randint(0, X_set.shape[0], (1,))][0] 
        
        if plot_loss_flag:
            plot_loss = []
            prev_loss = 0
            plato_iter = 0
            plato_reached = False
        
        for i in range(n_iter):
            
            # calculate loss
            loss = weighted_distance(Y, X_set, weights)
    
            if plot_loss_flag:
                if jnp.allclose(jnp.array(loss), jnp.array(prev_loss)):
                    if not plato_reached:
                        plato_iter = i
                        plato_reached = True
                else:
                    prev_loss = loss
                    plato_reached = False
        
            Y_grad = grad(weighted_distance, argnums= 0)(Y, X_set, weights)
            
            # calculate Riemannian gradient
            riem_grad_Y = Y_grad
            
            # update Y
            Y_step = Y - lr * riem_grad_Y
            
            # project new Y on manifold with retraction
            Y = Y_step
            
            if plot_loss_flag:
              # collect loss for plotting
              plot_loss.append(loss)
        
        if plot_loss_flag:
            print(f"Total loss: {weighted_distance(Y, X_set, weights)} got in {plato_iter} iterations")    
            fig, ax = plt.subplots()
            ax.plot(plot_loss)
            ax.set_xlabel("Iteration")
            ax.set_ylabel("Loss")
            plt.show()
        return Y
    

    You can launch it like this:

    d = 2
    m = 4
    X = jax.random.uniform(key, (m,d))
    euclidean_weighted_mean(X, weights = None, lr = 1e-3, n_iter = 100, plot_loss_flag = True)
    

    As you can see, I am calculating a weighted version of the mean, and that's where I use jaxopt. Let me define the global objective (just as an example): I want the weights to take the values that minimize the distance between the resulting mean and a desired point. In my case, I want the weights to steer the algorithm so that the resulting mean is as close to X[0] as possible:

    def global_task_objective(w, X, target_point, lr, n_iter):
        x = euclidean_weighted_mean(X, w, lr = lr, n_iter = n_iter)
        loss = euclidean_distance(x, target_point)
        return loss, x
    
    target_point = X[0]
    
    w_init = jnp.array(np.random.randn(X.shape[0])) * jnp.square(2 / X.shape[0]) 
    
    lr = 1e-3
    n_iter = 100
    
    global_task_objective(w_init, X, target_point, lr, n_iter)
    solver = OptaxSolver(opt=optax.amsgrad(1e-2), fun=global_task_objective, has_aux=True)
    state = solver.init_state(w_init, X=X, target_point=target_point, lr=lr, n_iter=n_iter)
    

    The problem emerges when I call

    w_init, state = solver.update(params=w_init, 
                                 state=state, 
                                 X=X, target_point=target_point, lr=lr, n_iter=n_iter)
    
    Meanwhile, the official example with ridge regression works perfectly. Any suggestions?
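As a reference point for the gradient-descent mean above: for weights with a positive sum, the minimizer of `weighted_distance` has a closed form. Setting the gradient 2 * sum_i w_i * (x - X_i) to zero gives x* = sum_i w_i X_i / sum_i w_i. A minimal sketch (a vectorized rewrite of the loop plus a hypothetical `closed_form_weighted_mean` helper) that can be differentiated with respect to the weights directly, which is handy for checking gradients obtained via implicit differentiation:

```python
import jax
import jax.numpy as jnp


def weighted_distance(x, X, w):
    # Vectorized form of the loop above: sum_i w_i * ||X_i - x||^2.
    return jnp.sum(w * jnp.sum((X - x) ** 2, axis=1))


def closed_form_weighted_mean(X, w):
    # Zero of the gradient 2 * sum_i w_i * (x - X_i), valid when sum(w) > 0.
    return (w @ X) / jnp.sum(w)


def global_task_objective(w, X, target_point):
    # Squared distance between the weighted mean and the target point.
    x = closed_form_weighted_mean(X, w)
    return jnp.sum((x - target_point) ** 2)


X = jax.random.uniform(jax.random.PRNGKey(0), (4, 2))
w = jnp.full((4,), 0.25)
x_star = closed_form_weighted_mean(X, w)
grad_w = jax.grad(global_task_objective)(w, X, X[0])
```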
    opened by MarioAuditore 0
  • Custom loop pjit example

    A MWE of how jax.experimental.pjit can be used in JAXopt (see also PR #346).

    NOTE: jax.experimental.pjit is not yet supported in Colab. However, this example illustrates how users with access to Google Cloud TPUs may use jax.experimental.pjit in combination with JAXopt solvers.

    pull ready 
    opened by fllinares 2
  • Added a new API allowing to warm start the inverse Hessian approximation in LBFGS

    This fixes #351.

    @mblondel I couldn't use your suggestion of creating a new type of init LBFGSInit because the init_params variable is used for both init_state and update. Therefore I would have had to add case distinctions in the 2 functions which seemed unreasonable. Rather I took the approach I saw in some other iterative solvers which was to add an extra keyword argument to init_state, update and _value_and_grad_fun.

    I added a test to make sure that this runs, but I am not sure whether we need to add a test to make sure that it improves some cases. I also don't know whether we should test that differentiation is ok.

    opened by zaccharieramzi 5
  • Enable warm-starting the hessian approximation in L-BFGS

    Currently one can only provide an initial estimate of the solution, which enables warm-starting the iterates. But for quasi-Newton methods, it can also be a good idea to provide an initial estimate of the Hessian approximation, typically when solving a similar problem multiple times.

    This was for example done in HOAG by @fabianp (see https://github.com/fabianp/hoag/blob/master/hoag/hoag.py#L109).

    I am willing to implement this in the next few weeks.

    As I know it is of interest to them as well, cc-ing @marius311 and @mblondel
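For context, the L-BFGS inverse-Hessian approximation is defined implicitly by the stored pairs s_k = x_{k+1} - x_k and y_k = grad f(x_{k+1}) - grad f(x_k), so warm-starting it amounts to reusing those pairs from a previous solve. A minimal sketch of the standard two-loop recursion in JAX (an illustration, not jaxopt's internal code; `lbfgs_direction` is a hypothetical name):

```python
import jax.numpy as jnp


def lbfgs_direction(grad, s_hist, y_hist):
    """Return -H @ grad, where H is the L-BFGS inverse-Hessian approximation
    built from the (s, y) pairs in s_hist / y_hist (oldest first)."""
    if not s_hist:
        return -grad  # no history: plain gradient direction
    rhos = [1.0 / jnp.dot(y, s) for s, y in zip(s_hist, y_hist)]
    q = grad
    alphas = []
    # First loop: newest pair to oldest.
    for s, y, rho in reversed(list(zip(s_hist, y_hist, rhos))):
        alpha = rho * jnp.dot(s, q)
        q = q - alpha * y
        alphas.append(alpha)
    # Initial scaling H0 = gamma * I with gamma = s^T y / y^T y (newest pair).
    gamma = jnp.dot(s_hist[-1], y_hist[-1]) / jnp.dot(y_hist[-1], y_hist[-1])
    r = gamma * q
    # Second loop: oldest pair to newest.
    for (s, y, rho), alpha in zip(zip(s_hist, y_hist, rhos), reversed(alphas)):
        beta = rho * jnp.dot(y, r)
        r = r + (alpha - beta) * s
    return -r


# Pairs from f(x) = 0.5 x^T A x with A-conjugate steps: here H equals A^{-1}.
A = jnp.diag(jnp.array([1.0, 2.0, 4.0]))
s_hist = [jnp.eye(3)[i] for i in range(3)]
y_hist = [A @ s for s in s_hist]
d = lbfgs_direction(jnp.ones(3), s_hist, y_hist)  # -[1.0, 0.5, 0.25]
```

Passing the `s_hist`/`y_hist` of an earlier solve instead of empty lists is the kind of warm start the issue describes.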

    opened by zaccharieramzi 2
  • Batched QP (and other optimization algorithm)

    I'm trying to make OSQP batchable (so I can use it as a layer in neural networks, like OptNet), but I couldn't find any documentation yet about using vmap to solve batched versions of optimization problems.
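OSQP itself is beyond a short example, but the batching pattern can be sketched on the simplest case: an equality-constrained QP solved through its KKT linear system with `jnp.linalg.solve`. Because the solve is pure JAX, `jax.vmap` batches it directly. `solve_eq_qp` is a hypothetical helper, and the objective 0.5 * x^T Q x + c^T x is an assumption:

```python
import jax
import jax.numpy as jnp


def solve_eq_qp(Q, c, A, b):
    """Solve min_x 0.5 x^T Q x + c^T x  s.t.  A x = b via its KKT system.

    Assumes Q is positive definite and A has full row rank.
    """
    n, m = Q.shape[0], A.shape[0]
    kkt = jnp.block([[Q, A.T], [A, jnp.zeros((m, m), Q.dtype)]])
    rhs = jnp.concatenate([-c, b])
    sol = jnp.linalg.solve(kkt, rhs)
    return sol[:n], sol[n:]  # primal x, dual nu


# Batch over the linear term c only; Q, A and b are shared (in_axes=None).
batched_solve = jax.vmap(solve_eq_qp, in_axes=(None, 0, None, None))

Q = jnp.array([[4.0, 1.0], [1.0, 2.0]])
A = jnp.array([[1.0, 1.0]])
b = jnp.array([1.0])
cs = jnp.array([[0.0, 0.0], [1.0, 0.0]])
xs, nus = batched_solve(Q, cs, A, b)  # xs has shape (2, 2)
```

Changing `in_axes` controls which problem data carry a batch dimension.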

    opened by jn-tang 1
Releases(jaxopt-v0.5.5)
  • jaxopt-v0.5.5(Oct 20, 2022)

    New features

    • Added MAML example by Fabian Pedregosa based on initial code by Paul Vicol and Eric Jiang.
    • Added the possibility to stop LBFGS after a line search failure, by Zaccharie Ramzi.
    • Added gamma to LBFGS state, by Zaccharie Ramzi.
    • Added jaxopt.BFGS, by Mathieu Blondel.
    • Added value_and_grad option to all gradient-based solvers, by Mathieu Blondel.
    • Added Fenchel-Young loss, by Quentin Berthet.
    • Added projection_sparse_simplex, by Tianlin Liu.

    Bug fixes and enhancements

    • Fixed missing args,kwargs in resnet example, by Louis Béthune.
    • Corrected the implicit diff examples, by Zaccharie Ramzi.
    • Small optimization in l2-regularized semi-dual OT, by Mathieu Blondel.
    • Numerical stability improvements in jaxopt.LevenbergMarquardt, by Amir Saadat.
    • Dtype consistency in LBFGS, by Alex Botev.

    Deprecations

    • jaxopt.QuadraticProgramming is now fully removed. Use jaxopt.CvxpyQP, jaxopt.OSQP, jaxopt.BoxOSQP and jaxopt.EqualityConstrainedQP instead.

    Contributors

    Alex Botev, Amir Saadat, Fabian Pedregosa, Louis Béthune, Mathieu Blondel, Quentin Berthet, Tianlin Liu, Zaccharie Ramzi.

  • jaxopt-v0.5(Aug 30, 2022)

    New features

    • Added optimal transport related projections: projection_transport, projection_birkhoff, kl_projection_transport, and kl_projection_birkhoff, by Mathieu Blondel (semi-dual formulation) and Tianlin Liu (dual formulation).

    Bug fixes and enhancements

    • Fix LaTeX rendering issue in notebooks, by Amélie Héliou.
    • Avoid gradient recompilations in zoom line search, by Mathieu Blondel.
    • Fix unused Jacobian issue in jaxopt.ScipyRootFinding, by Louis Béthune.
    • Use zoom line search by default in jaxopt.LBFGS and jaxopt.NonlinearCG, by Mathieu Blondel.
    • Pass tolerance argument to jaxopt.ScipyMinimize, by pipme.
    • Handle has_aux in jaxopt.LevenbergMarquardt, by Keunhong Park.
    • Add maxiter keyword argument in jaxopt.ScipyMinimize, by Fabian Pedregosa.

    Contributors

    Louis Béthune, Mathieu Blondel, Amélie Héliou, Keunhong Park, Fabian Pedregosa, pipme.

  • jaxopt-v0.4.3(Jun 28, 2022)

    New features

    • Added zoom line search in jaxopt.LBFGS, by Mathieu Blondel. It can be enabled with the linesearch="zoom" option.

    Bug fixes and enhancements

    • Added support for quadratic polynomial fun in jaxopt.BoxOSQP and jaxopt.OSQP, by Louis Béthune.
    • Added a notebook for the dataset distillation example, by Amélie Héliou.
    • Fixed wrong links and deprecation warnings in notebooks, by Fabian Pedregosa.
    • Changed losses to avoid roundoff, by Jack Valmadre.
    • Fixed init_params bug in multiclass_svm example, by Louis Béthune.

    Contributors

    Louis Béthune, Mathieu Blondel, Amélie Héliou, Fabian Pedregosa, Jack Valmadre.

  • jaxopt-v0.4.2(Jun 10, 2022)

  • jaxopt-v0.4.1(Jun 10, 2022)

    Bug fixes and enhancements

    • Improvements in jaxopt.LBFGS: fixed bug when using use_gamma=True, added stepsize option, strengthened tests, by Mathieu Blondel.
    • Fixed link in resnet notebook, by Fabian Pedregosa.

    Contributors

    Fabian Pedregosa, Mathieu Blondel.

  • jaxopt-v0.4(May 24, 2022)

    New features

    • Added solver jaxopt.LevenbergMarquardt, by Amir Saadat.
    • Added solver jaxopt.BoxCDQP, by Mathieu Blondel.
    • Added projection_hypercube, by Mathieu Blondel.

    Bug fixes and enhancements

    • Fixed solve_normal_cg when the linear operator is “nonsquare” (does not map to a space of same dimension), by Mathieu Blondel.
    • Fixed edge case in jaxopt.Bisection, by Mathieu Blondel.
    • Replaced deprecated tree_multimap with tree_map, by Fan Yang.
    • Added support for leaf cond pytrees in tree_where, by Felipe Llinares.
    • Added Python 3.10 support officially, by Jeppe Klitgaard.
    • In scipy wrappers, converted pytree leaves to jax arrays to determine their shape/dtype, by Roy Frostig.
    • Converted the “Resnet” and “Adversarial Training” examples to notebooks, by Fabian Pedregosa.

    Contributors

    Amir Saadat, Fabian Pedregosa, Fan Yang, Felipe Llinares, Jeppe Klitgaard, Mathieu Blondel, Roy Frostig.

  • jaxopt-v0.3.1(Feb 28, 2022)

    New features

    • Pjit-based example of data parallel training using Flax, by Felipe Llinares.

    Bug fixes and enhancements

    • Support for GPU and state of the art adversarial training algorithm (PGD) on the robust_training.py example, by Fabian Pedregosa.
    • Update line search in LBFGS to use jit and unroll from LBFGS, by Ian Williamson.
    • Support dynamic maximum iteration count in iterative solvers, by Roy Frostig.
    • Fix tree_where for singleton pytrees, by Louis Béthune.
    • Remove QuadraticProg in projections and set init_params=None by default in QP solvers, by Louis Béthune.
    • Add missing 'value' attribute in LbfgsState, by Mathieu Blondel.

    Contributors

    Felipe Llinares, Fabian Pedregosa, Ian Williamson, Louis Béthune, Mathieu Blondel, Roy Frostig.

  • jaxopt-v0.3(Jan 31, 2022)

    New features

    • jaxopt.LBFGS
    • jaxopt.BacktrackingLineSearch
    • jaxopt.GaussNewton
    • jaxopt.NonlinearCG

    Bug fixes and enhancements

    • Support implicit AD in higher-order differentiation.

    Contributors

    Amir Saadat, Fabian Pedregosa, Geoffrey Négiar, Hyunsung Lee, Mathieu Blondel, Roy Frostig.

  • jaxopt-v0.2(Dec 18, 2021)

    New features

    • Quadratic programming solvers jaxopt.CvxpyQP, jaxopt.OSQP, jaxopt.BoxOSQP and jaxopt.EqualityConstrainedQP
    • Iterative refinement

    New examples

    • Resnet example with Flax and JAXopt.

    Bug fixes and enhancements

    • Prevent recompilation of loops in solver.run if executing without jit.
    • Prevents recomputation of gradient in OptaxSolver.
    • Make solver.update jittable and ensure output states are consistent.
    • Allow Callable for the stepsize argument in jaxopt.ProximalGradient, jaxopt.ProjectedGradient and jaxopt.GradientDescent.

    Deprecated features

    • jaxopt.QuadraticProgramming is deprecated and will be removed in v0.3. Use jaxopt.CvxpyQP, jaxopt.OSQP, jaxopt.BoxOSQP and jaxopt.EqualityConstrainedQP instead.

    Contributors

    Fabian Pedregosa, Felipe Llinares, Geoffrey Negiar, Louis Bethune, Mathieu Blondel, Vikas Sindhwani.

  • jaxopt-v0.1.1(Oct 19, 2021)

    New features

    • Added solver jaxopt.ArmijoSGD
    • Added example Deep Equilibrium (DEQ) model in Flax with Anderson acceleration.
    • Added example Comparison of different SGD algorithms.

    Bug fixes

    • Allow non-jittable proximity operators in jaxopt.ProximalGradient
    • Raise an exception if a quadratic program is infeasible or unbounded

    Contributors

    Fabian Pedregosa, Louis Bethune, Mathieu Blondel.

  • jaxopt-v0.1(Oct 14, 2021)

    Classes

    • jaxopt.AndersonAcceleration
    • jaxopt.AndersonWrapper
    • jaxopt.Bisection
    • jaxopt.BlockCoordinateDescent
    • jaxopt.FixedPointIteration
    • jaxopt.GradientDescent
    • jaxopt.MirrorDescent
    • jaxopt.OptaxSolver
    • jaxopt.PolyakSGD
    • jaxopt.ProjectedGradient
    • jaxopt.ProximalGradient
    • jaxopt.QuadraticProgramming
    • jaxopt.ScipyBoundedLeastSquares
    • jaxopt.ScipyBoundedMinimize
    • jaxopt.ScipyLeastSquares
    • jaxopt.ScipyMinimize
    • jaxopt.ScipyRootFinding
    • Implicit differentiation

    Examples

    • Binary kernel SVM with intercept.
    • Image classification example with Flax and JAXopt.
    • Image classification example with Haiku and JAXopt.
    • VAE example with Haiku and JAXopt.
    • Implicit differentiation of lasso.
    • Multiclass linear SVM (without intercept).
    • Non-negative matrix factorization (NMF) using alternating minimization.
    • Dataset distillation.
    • Implicit differentiation of ridge regression.
    • Robust training.
    • Anderson acceleration of gradient descent.
    • Anderson acceleration of block coordinate descent.
    • Anderson acceleration in application to Picard–Lindelöf theorem.

    Contributors

    Fabian Pedregosa, Felipe Llinares, Robert Gower, Louis Bethune, Marco Cuturi, Mathieu Blondel, Peter Hawkins, Quentin Berthet, Roy Frostig, Ta-Chu Kao.

Owner
Google