Hardware accelerated, batchable and differentiable optimizers in JAX.

Overview

JAXopt

Installation | Examples | References

Hardware accelerated (GPU/TPU), batchable and differentiable optimizers in JAX.

Installation

JAXopt can be installed with pip directly from GitHub using the following command:

$ pip install git+https://github.com/google/jaxopt

Alternatively, it can be installed from source with the following command:

$ python setup.py install
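
Once installed, a solver can be instantiated and run in a few lines. Here is a minimal sketch with a toy quadratic objective (the objective is illustrative, not part of the library):

import jax.numpy as jnp
import jaxopt

def fun(w):
    # A simple quadratic objective with minimum at w = 1.
    return jnp.sum((w - 1.0) ** 2)

solver = jaxopt.GradientDescent(fun=fun, maxiter=100)
params, state = solver.run(jnp.zeros(3))
print(params)  # approximately [1., 1., 1.]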

References

Our implicit differentiation framework is described in this paper. To cite it:

@article{jaxopt_implicit_diff,
  title={Efficient and Modular Implicit Differentiation},
  author={Blondel, Mathieu and Berthet, Quentin and Cuturi, Marco and Frostig, Roy and Hoyer, Stephan and Llinares-L{\'o}pez, Felipe and Pedregosa, Fabian and Vert, Jean-Philippe},
  journal={arXiv preprint arXiv:2105.15183},
  year={2021}
}
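
In practice, this framework lets you differentiate a solver's solution with respect to hyperparameters of the objective. A minimal ridge-regression sketch (the data and objective below are illustrative, not taken from the paper):

import jax
import jax.numpy as jnp
import jaxopt

def ridge_loss(w, lam, X, y):
    residual = X @ w - y
    return jnp.mean(residual ** 2) + lam * jnp.sum(w ** 2)

def ridge_solution(lam, X, y):
    solver = jaxopt.GradientDescent(fun=ridge_loss, maxiter=500, implicit_diff=True)
    return solver.run(jnp.zeros(X.shape[1]), lam, X, y).params

X = jnp.arange(12.0).reshape(4, 3)
y = jnp.array([1.0, 2.0, 3.0, 4.0])
outer = lambda lam: jnp.sum(ridge_solution(lam, X, y) ** 2)
print(jax.grad(outer)(0.1))  # gradient of the solution map w.r.t. the regularization strength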

Disclaimer

JAXopt is an open source project maintained by a dedicated team in Google Research, but is not an official Google product.

Comments
  • Levenberg-Marquardt running exceptionally slow unless verbose is enabled

    Levenberg-Marquardt running exceptionally slow unless verbose is enabled

    The Levenberg-Marquardt optimizer runs exceptionally slowly (~30 seconds for 30 iterations) until I turn on verbose=True (~1 second for 30 iterations). Any idea what may be going on? Enabling JIT seems to have no impact. I was hoping to use this for a real-time system, but even at 1 second things are way too slow.
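
    A minimal timing sketch for separating compilation time from per-call time (the residual function is a placeholder, not the model from this report):

    import time
    import jax
    import jax.numpy as jnp
    import jaxopt

    def residual_fun(params):
        # Placeholder least-squares residuals (Rosenbrock-style).
        return jnp.array([params[0] - 1.0, 10.0 * (params[1] - params[0] ** 2)])

    lm = jaxopt.LevenbergMarquardt(residual_fun=residual_fun, maxiter=30)
    run = jax.jit(lm.run)

    params, state = run(jnp.zeros(2))   # first call includes compilation
    params.block_until_ready()
    start = time.perf_counter()
    params, state = run(jnp.zeros(2))   # subsequent calls should be fast
    params.block_until_ready()
    print("per-call time:", time.perf_counter() - start)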

    opened by pablovela5620 23
  • implementation of Fletcher-Reeves Algorithm

    implementation of Fletcher-Reeves Algorithm

    • Polak-Ribière method; to my knowledge, conjugate gradient variants have been quite successful on general unconstrained optimization

    This PR depends on Line Search of PR #128.

    • Beta division is required to guarantee the strong Wolfe condition, but (I don't know why) it raises an error. The two beta rules in question are sketched below.
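
    For reference, a minimal sketch of the two conjugate-gradient beta rules mentioned above (standard textbook formulas, not the code from this PR):

    import jax.numpy as jnp

    def beta_fletcher_reeves(grad_new, grad_old):
        return jnp.vdot(grad_new, grad_new) / jnp.vdot(grad_old, grad_old)

    def beta_polak_ribiere(grad_new, grad_old):
        beta = jnp.vdot(grad_new, grad_new - grad_old) / jnp.vdot(grad_old, grad_old)
        return jnp.maximum(beta, 0.0)  # the common "PR+" clipping to help ensure descent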
    pull ready 
    opened by ita9naiwa 17
  • vmap support in QPs

    vmap support in QPs

    Hi, I am experiencing some problems with projection_polyhedron.

    import numpy as np
    import matplotlib.pyplot as plt
    
    import jax
    import jax.numpy as jnp
    
    import jaxopt
    from jaxopt.projection import projection_l2_ball, projection_box, projection_l1_ball, projection_polyhedron
    
    def myproj3(x):
        A = jnp.array([[1.0, 1.0]])
        b = jnp.array([1.0])
        G = jnp.array([[-1.0, 0.0], [0.0, -1.0]])
        h = jnp.array([0.0, 0.0])    
        x = projection_polyhedron(x,hyperparams = (A, b, G, h))
        return x
    
    rng_key = jax.random.PRNGKey(42)
    x = jax.random.uniform(rng_key, (5000,2), minval=-3,maxval=3)
    p1_x = jax.vmap(myproj3, in_axes=0)(x)  # myproj3 takes a single argument, so in_axes=0
    fig, ax = plt.subplots(figsize=(5,5))
    ax.scatter(x[:,0],x[:,1],s=0.5)
    ax.scatter(p1_x[:,0],p1_x[:,1],s=0.5,c='g')
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    plt.show()
    

    First, I had to install cvxpy (pip install cvxpy). Then, I got this error:

    TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float64[2])>with<BatchTrace(level=1/1)>
      with val = DeviceArray([[-2.37103211,  2.33759997],
                              [ 2.76953806, -2.37750394],
                              [-0.87246632,  0.73224625],
                              ...,
                              [ 2.29799773,  2.81894884],
                              [ 2.4022714 ,  0.80693103],
                              [-0.41563116,  2.83898531]], dtype=float64)
           batch_dim = 0
    

    Does anyone have a hint? Thanks
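
    As an aside, a slow but working fallback (assuming the snippet above) is to project each row in a plain Python loop, since projection_polyhedron works on concrete, unbatched inputs:

    # One cvxpy solve per row, so this is very slow for 5000 points, but it avoids vmap.
    p1_x = jnp.stack([myproj3(row) for row in x])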

    enhancement 
    opened by jecampagne 12
  • KKT conditions when the primal solution is a pytree

    KKT conditions when the primal solution is a pytree

    Hi, congrats on the great tool! Inspired by the QuadraticProgramming example, I built code that differentiates through KKT conditions. My code works whenever the primal solution variable is a jnp array, but not when it's a generic pytree, giving me the following issue:

    TypeError: Tree structure of cotangent input PyTreeDef(([(*, *), (), (*, *)], *, None)), does not match structure of primal output PyTreeDef(([(*, *), (), (*, *), (*, *), (), (*, *)], *, None))

    where I'm pretty sure [(*, *), (), (*, *)] represents the primal solution and PyTreeDef(([(*, *), (), (*, *)], *, None)) could represent the optimality function.

    I was able to make it work by storing the primal solution in a single jnp array and reshaping it into the appropriate pytree whenever needed, but it's not clean or efficient. I was wondering if there's a bug in the current codebase (I only found tests for single jnp arrays) or I'm misusing the interface (I'm not a jax expert).
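
    The flatten/unflatten workaround mentioned above can be sketched with jax.flatten_util (primal_pytree is a placeholder for the actual parameters):

    from jax.flatten_util import ravel_pytree

    flat_primal, unravel = ravel_pytree(primal_pytree)  # 1-D array plus its inverse map
    # ... treat flat_primal as the primal variable in the KKT conditions ...
    primal_pytree = unravel(flat_primal)                # recover the original pytree structure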

    To make it easier to reproduce I modified the quadratic_prog.py file by making the model return a list of one array instead of an array for the primal variables (leaving both dual variables the same). Then I modified the obj_fun, eq_fun and ineq_fun to use primal_var[0] instead of primal_var. If I understand correctly, this should still work. However, it doesn't, this test line raises an assert for an array that should be all zeros and instead is: ([DeviceArray([ 0.43999994, -1.3199999 ], dtype=float32), DeviceArray([-0.44000003, 1.32 ], dtype=float32)], DeviceArray([2.9802322e-08], dtype=float32), None)

    Looking at the numbers of the problem I believe [0.44,-1.32] is the gradient of the obj_fun w.r.t. the primal and [-0.44,+1.32] the gradient of the equality constraint w.r.t. the primal times the dual. They should have been added up together to have [0,0] as expected. I feel this may be fundamentally the same problem I was facing in my own research code since there I also found one of the values had the shape of the primal variable twice instead of once.

    Notice also that the test on the line just above (checking that the primal solution is correct) still holds, provided we check sol[0][0] instead of sol[0] (since sol[0] is now a 1-element list).

    Is differentiation through KKT supposed to work for general pytrees? If so, what should I have done to make it work in the quadratic_prog.py example?

    Thanks!

    opened by FerranAlet 11
  • Hot fix: corrected condition in lbfgs

    Hot fix: corrected condition in lbfgs

    The feature I had introduced in https://github.com/google/jaxopt/pull/323 was failing when the run function was jitted, and was a no-op when it was not, for the following reason:

     ~True == -2  # this is True
    

    Therefore when jitted it was complaining about different types in a condition function, and when not jitted it was equivalent to always being False.

    EDIT

    Actually I am still running into an error when jitted, so will continue to investigate.

    The gist of the error is "Abstract tracer value encountered where concrete value is expected": basically, doing (not self.stop_if_linesearch_fails | ~state.failed_linesearch) is not allowed because one operand is a Python bool and the other is an abstract value.
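
    A plain-Python illustration of the two pitfalls above (a sketch, not the actual fix in this PR; the variable names are placeholders):

    import jax.numpy as jnp

    assert ~True == -2                       # bitwise NOT on a Python bool gives an int, never False

    stop_if_linesearch_fails = False         # static Python bool (a solver option)
    failed_linesearch = jnp.asarray(True)    # traced/abstract boolean under jit

    # One jit-friendly way to combine a static bool with a traced one:
    keep_going = jnp.logical_or(not stop_if_linesearch_fails,
                                jnp.logical_not(failed_linesearch))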

    pull ready 
    opened by zaccharieramzi 7
  • Issue with gradients wrt optimality fn parameters through root finding vjp

    Issue with gradients wrt optimality fn parameters through root finding vjp

    First of all, thanks a lot for this library! Really useful tools! I'm interested in getting at least 2nd order gradients through root finding, and I'm finding an odd behavior that I wanted to report.

    Maybe I'm doing something wrong, but in the following schematic case I silently get the wrong gradients:

    import jax
    from jaxopt import Bisection

    # F is the optimality function whose root is sought (left schematic in this report).
    def inv_f(x, aux):
      bisec = Bisection(optimality_fun=F, lower=0.0, upper=1.,
                        check_bracket=False, unroll=True)
      return bisec.run(aux=aux).params
    
    # Here I extract the value part of the vjp, but the grad part also gives wrong results
    test_fn = lambda aux: jax.value_and_grad(inv_f)(0.5, aux)[0] 
    
    jax.grad(test_fn)(1.) # Returns 0 instead of the expected gradients
    

    Here I'm only trying to get gradients of the value returned by jax.value_and_grad, but the gradients of the gradients returned by jax.value_and_grad are also wrong (but not as obvious).

    I made a small demo notebook that reproduces this issue here.

    As a reference I've also implemented my own implicit gradients, bypassing the jaxopt ones, and they seem to give me the correct answer.
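
    For a scalar root, the manual implicit gradient mentioned above can be sketched via the implicit function theorem (a sketch, not the jaxopt implementation): if F(x*(theta), theta) = 0, then dx*/dtheta = -(dF/dtheta) / (dF/dx).

    import jax

    def manual_root_grad(F, x_star, theta):
        # Implicit function theorem for a scalar root x_star of F(., theta).
        dF_dx = jax.grad(F, argnums=0)(x_star, theta)
        dF_dtheta = jax.grad(F, argnums=1)(x_star, theta)
        return -dF_dtheta / dF_dx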

    Reading the source code of jaxopt, it is not immediately obvious to me why this doesn't work... Sorry I couldn't directly suggest a PR, but I hope this report is still useful (and that I'm not just using jaxopt wrong).

    bug 
    opened by EiffL 7
  • misc improvements to robust training example

    misc improvements to robust training example

    main changes:

    • Fixes #134 by normalizing in-place.
    • Plot convergence curves for both clean and adversarial accuracy.
    • Replace the fast gradient sign method (FGSM) with the much more powerful PGD method (a rough PGD sketch follows this list).
    • Be able to select different datasets.
    • Homogenize the API w.r.t. the other examples. For example, this now uses the same load_dataset, CNN, loss_fun, and accuracy as flax_image_classif.py. Most of the command-line flags have also been homogenized.
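
    A rough sketch of an l-infinity PGD attack of the kind referred to above (illustrative only, not the example's actual code; loss_fn, eps and alpha are placeholders):

    import jax
    import jax.numpy as jnp

    def pgd_attack(loss_fn, x, y, eps=8/255, alpha=2/255, steps=10):
        x_adv = x
        for _ in range(steps):
            g = jax.grad(lambda xi: loss_fn(xi, y))(x_adv)
            x_adv = x_adv + alpha * jnp.sign(g)          # ascent step on the loss
            x_adv = jnp.clip(x_adv, x - eps, x + eps)    # project back into the eps-ball
            x_adv = jnp.clip(x_adv, 0.0, 1.0)            # keep a valid pixel range
        return x_adv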
    pull ready 
    opened by fabianp 7
  • Bisection hanging

    Bisection hanging

    I am trying to use jaxopt.Bisection to replace the use of scipy.optimize.bisect in a computational model, but Bisection hangs when I run my code.

    The basic structure includes 2 functions that are both jitted (so I assume it should be able to compile ok):

    import jax
    from jax import jit

    @jit
    def f1(x, params):
        ....
        return jax.numpy.array([a, b, c])

    @jit
    def opt_fun(x):
        f1(x, params)
        ....
        return float_value
    

    When I call scipy.optimize.bisect(opt_fun, x0, x1) it runs with no issue, but jaxopt.Bisection(opt_fun, x0, x1).run(None) hangs with ~10% CPU usage and 55% memory usage on an i9 2018 MacBook Pro with 32 GB of memory.

    I acknowledge I may be using this incorrectly and that this is possibly not the intended use case but any direction would be very helpful. My intention is to use this computational model with numpyro in the future and having a jax version of the bisection root finding would be incredibly helpful.
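
    For comparison, a minimal self-contained Bisection sketch that does terminate (illustrative only, not the poster's model):

    import jaxopt

    def optimality_fun(x):
        return x ** 3 - 2.0   # sign change on [0, 2], root at 2 ** (1/3)

    bisec = jaxopt.Bisection(optimality_fun=optimality_fun, lower=0.0, upper=2.0, maxiter=50)
    root = bisec.run().params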

    opened by jjruby09 7
  • Incompatible shape in solve_normal_cg

    Incompatible shape in solve_normal_cg

    When A.shape = (N, P) for N != P, I run into shape errors when trying to use solve_normal_cg to solve the normal equations.

    I have a small reproducible example below for N > P, but the error holds for when P > N.

    import jax.numpy as jnp
    import numpy as np
    N = 1000
    P = 3
    prob = np.random.uniform(0.01, 0.5, size=P)
    h2g = 0.1
    X = np.random.binomial(2, p=prob, size=(N, P))
    b = np.random.normal(size=(P)) * np.sqrt(h2g / P)
    y = X @ b + np.sqrt(1 - h2g) * np.random.normal(size=(N,))
    
    import jaxopt as jopt
    jopt.linear_solve.solve_normal_cg(lambda x: jnp.dot(X, x), y)
    WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
    ---------------------------------------------------------------------------
    TypeError                                 Traceback (most recent call last)
    Input In [11], in <module>
    ----> 1 jopt.linear_solve.solve_normal_cg(lambda x: jnp.dot(X, x), y)
    
    File ~/miniconda3/lib/python3.9/site-packages/jaxopt/_src/linear_solve.py:151, in solve_normal_cg(matvec, b, ridge, init, **kwargs)
        148 if ridge is not None:
        149   _matvec = _make_ridge_matvec(_matvec, ridge=ridge)
    --> 151 Ab = _rmatvec(matvec, b)
        153 return jax.scipy.sparse.linalg.cg(_matvec, Ab, x0=init, **kwargs)[0]
    
    File ~/miniconda3/lib/python3.9/site-packages/jaxopt/_src/linear_solve.py:114, in _rmatvec(matvec, x)
        112 def _rmatvec(matvec, x):
        113   """Computes A^T x, from matvec(x) = A x, where A is square."""
    --> 114   transpose = jax.linear_transpose(matvec, x)
        115   return transpose(x)[0]
    
    File ~/miniconda3/lib/python3.9/site-packages/jax/_src/api.py:2211, in linear_transpose(fun, reduce_axes, *primals)
       2208 in_dtypes = map(dtypes.dtype, in_avals)
       2210 in_pvals = map(pe.PartialVal.unknown, in_avals)
    -> 2211 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(flat_fun, in_pvals,
       2212                                              instantiate=True)
       2213 out_avals, _ = unzip2(out_pvals)
       2214 out_dtypes = map(dtypes.dtype, out_avals)
    
    File ~/miniconda3/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:505, in trace_to_jaxpr(fun, pvals, instantiate)
        503 with core.new_main(JaxprTrace) as main:
        504   fun = trace_to_subjaxpr(fun, main, instantiate)
    --> 505   jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
        506   assert not env
        507   del main, fun, env
    
    File ~/miniconda3/lib/python3.9/site-packages/jax/linear_util.py:166, in WrappedFun.call_wrapped(self, *args, **kwargs)
        163 gen = gen_static_args = out_store = None
        165 try:
    --> 166   ans = self.f(*args, **dict(self.params, **kwargs))
        167 except:
        168   # Some transformations yield from inside context managers, so we have to
        169   # interrupt them before reraising the exception. Otherwise they will only
        170   # get garbage-collected at some later time, running their cleanup tasks only
        171   # after this exception is handled, which can corrupt the global state.
        172   while stack:
    
    Input In [11], in <lambda>(x)
    ----> 1 jopt.linear_solve.solve_normal_cg(lambda x: jnp.dot(X, x), y)
    
    File ~/miniconda3/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py:4196, in dot(a, b, precision)
       4194   return lax.mul(a, b)
       4195 if _max(a_ndim, b_ndim) <= 2:
    -> 4196   return lax.dot(a, b, precision=precision)
       4198 if b_ndim == 1:
       4199   contract_dims = ((a_ndim - 1,), (0,))
    
    File ~/miniconda3/lib/python3.9/site-packages/jax/_src/lax/lax.py:667, in dot(lhs, rhs, precision, preferred_element_type)
        664   return dot_general(lhs, rhs, (((lhs.ndim - 1,), (0,)), ((), ())),
        665                      precision=precision, preferred_element_type=preferred_element_type)
        666 else:
    --> 667   raise TypeError("Incompatible shapes for dot: got {} and {}.".format(
        668       lhs.shape, rhs.shape))
    
    TypeError: Incompatible shapes for dot: got (1000, 3) and (1000,).
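
    A possible workaround sketch while the nonsquare case is unsupported: form the normal equations explicitly and hand them to the conjugate-gradient solver (assumes the X and y defined above):

    import jax.numpy as jnp
    from jax.scipy.sparse.linalg import cg

    Xj = jnp.asarray(X, dtype=jnp.float32)
    yj = jnp.asarray(y, dtype=jnp.float32)
    gram_matvec = lambda v: Xj.T @ (Xj @ v)   # v -> A^T A v
    rhs = Xj.T @ yj                           # A^T b
    beta_hat, _ = cg(gram_matvec, rhs)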
    
    opened by quattro 6
  • Initial stepsize not exposed in LBFGS constructor [question/bug?]

    Initial stepsize not exposed in LBFGS constructor [question/bug?]

    I see that LbfgsState contains a stepsize and that LBFGS.init_state hard-codes it to 1. I also see that the LBFGS.update method performs a line search in which the initial step size is set from this LBFGS state.

    I have a particularly ill-conditioned problem that requires tiny initial steps, but I was surprised that the initial stepsize could not be set in the LBFGS constructor or elsewhere as far as I could see. Is this an oversight or an intentional part of the design? If it's intentional, is there an idiomatic way to set an initial stepsize when using LBFGS.run that I have overlooked?

    Thanks in advance, and thanks for a really cool library.
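
    One possible (unofficial and untested) workaround sketch, relying on LbfgsState being a NamedTuple with the stepsize field described above; the objective and initial point are placeholders:

    import jax.numpy as jnp
    import jaxopt

    fun = lambda w: jnp.sum(w ** 2)              # placeholder objective
    init_params = jnp.ones(3)

    solver = jaxopt.LBFGS(fun=fun)
    state = solver.init_state(init_params)
    state = state._replace(stepsize=jnp.asarray(1e-4))   # override the hard-coded initial stepsize
    params, state = solver.update(init_params, state)    # then drive the update loop manually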

    opened by erdmann 6
  • Infinities and NaNs in quadratic_prog when c=0

    Infinities and NaNs in quadratic_prog when c=0

    Hi,

    I'm using QuadraticProgramming in the special case of c=0 (all zeros as a vector). AFAIK this is still well-defined, as it's just minimizing a Q-weighted squared norm of the primal subject to some equality constraints (I don't have inequalities).

    However, both my research code and the following modification of this test diverge even for a single step (maxiter=1).

    The modification just involves setting c=0, so:

    def test_qp_eq_only_c_zero(self):
      Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]])
      c = jnp.array([0.0, 0.0]) #ONLY CHANGE
      A = jnp.array([[1.0, 1.0]])
      b = jnp.array([1.0])
      qp = QuadraticProgramming(tol=1e-7)
      hyperparams = dict(params_obj=(Q, c), params_eq=(A, b))
      sol = qp.run(**hyperparams).params
      self.assertAllClose(qp.l2_optimality_error(sol, **hyperparams), 0.0)
      self._check_derivative_A_and_b(qp, hyperparams, A, b)
    

    Is there a way to fix it? If it involves calling another linear solver, is there a way to specify the solver from the high-level QP function? I haven't seen it.
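
    For what it's worth, in this equality-constrained special case the KKT system can also be solved directly; a sketch using the same Q, A, b as the test above:

    import jax.numpy as jnp

    Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1.0]])
    A = jnp.array([[1.0, 1.0]])
    b = jnp.array([1.0])

    # KKT system for min 0.5 x^T Q x  s.t.  A x = b  (c = 0, so the top block of the right-hand side is zero).
    kkt = jnp.block([[Q, A.T], [A, jnp.zeros((A.shape[0], A.shape[0]))]])
    rhs = jnp.concatenate([jnp.zeros(Q.shape[0]), b])
    solution = jnp.linalg.solve(kkt, rhs)
    x_star, nu_star = solution[:Q.shape[0]], solution[Q.shape[0]:]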

    Thanks!

    opened by FerranAlet 6
  • OptaxSolver Error: too many positional arguments

    OptaxSolver Error: too many positional arguments

    Hello! I tried to implement the example of implicit differentiation as shown here but with my own functions. The task is to find the mean of a set of vectors named X via gradient descent.

    import numpy as np
    import matplotlib.pyplot as plt
    from tqdm import tqdm
    
    import jax
    import jax.numpy as jnp
    from jax import grad, random, jit
    from jax import jacobian, hessian, jacfwd, jacrev
    key = random.PRNGKey(0)
    
    import jaxopt
    from jaxopt import implicit_diff
    from jaxopt import linear_solve
    from jaxopt import OptaxSolver, GradientDescent
    import optax
    
    def euclidean_distance(a, b):
        """
        Squared Euclidean distance
        """
        return jnp.inner(a - b, a - b)
    
    def weighted_distance(x, X, w):
        loss = 0
        for i, obj in enumerate(X):
            loss += w[i] * euclidean_distance(obj, x)
        return loss
    
    def identical(Y, Y_grad):
        return Y
    

    Algorithm for finding mean:

    # Mean calculation for manifolds with gradient descent
    @implicit_diff.custom_root(jax.grad(weighted_distance))
    def euclidean_weighted_mean(X_set, weights = None, lr = 0.1, n_iter = 50, plot_loss_flag = False):
        
        if weights is None:
            weights = jnp.full((X_set.shape[0]), 1) / X_set.shape[0]
    
        # init mean with random element from set
        Y = X_set[np.random.randint(0, X_set.shape[0], (1,))][0] 
        
        if plot_loss_flag:
            plot_loss = []
            prev_loss = 0
            plato_iter = 0
            plato_reached = False
        
        for i in range(n_iter):
            
            # calculate loss
            loss = weighted_distance(Y, X_set, weights)
    
            if plot_loss_flag:
                if jnp.allclose(jnp.array(loss), jnp.array(prev_loss)):
                    if not plato_reached:
                        plato_iter = i
                        plato_reached = True
                else:
                    prev_loss = loss
                    plato_reached = False
        
            Y_grad = grad(weighted_distance, argnums= 0)(Y, X_set, weights)
            
            # calculate Riemannian gradient
            riem_grad_Y = Y_grad
            
            # update Y
            Y_step = Y - lr * riem_grad_Y
            
            # project new Y on manifold with retraction
            Y = Y_step
            
            if plot_loss_flag:
              # collect loss for plotting
              plot_loss.append(loss)
        
        if plot_loss_flag:
            print(f"Total loss: {weighted_distance(Y, X_set, weights)} got in {plato_iter} iterations")    
            fig, ax = plt.subplots()
            ax.plot(plot_loss)
            ax.set_xlabel("Iteration")
            ax.set_ylabel("Loss")
            plt.show()
        return Y
    

    You can launch it like this:

    d = 2
    m = 4
    X = jax.random.uniform(key, (m,d))
    euclidean_weighted_mean(X, weights = None, lr = 1e-3, n_iter = 100, plot_loss_flag = True)
    

    As you can see, I am calculating the weighted version of the mean, and that's where I use jaxopt. Let me define the global objective (just as an example): I want the weights to take the values that minimise the distance between the resulting mean and a desired point. In my case, I want the weights to influence the algorithm in such a way that the resulting mean will be as close to X[0] as possible:

    def global_task_objective(w, X, target_point, lr, n_iter):
        x = euclidean_weighted_mean(X, w, lr = lr, n_iter = n_iter)
        loss = euclidean_distance(x, target_point)
        return loss, x
    
    target_point = X[0]
    
    w_init = jnp.array(np.random.randn(X.shape[0])) * jnp.square(2 / X.shape[0]) 
    
    lr = 1e-3
    n_iter = 100
    
    global_task_objective(w_init, X, target_point, lr, n_iter)
    solver = OptaxSolver(opt=optax.amsgrad(1e-2), fun=global_task_objective, has_aux=True)
    state = solver.init_state(w_init, X=X, target_point=target_point, lr=lr, n_iter=n_iter)
    

    The problem emerges when I call

    w_init, state = solver.update(params=w_init, 
                                 state=state, 
                                 X=X, target_point=target_point, lr=lr, n_iter=n_iter)
    
    [image: traceback screenshot] Meanwhile, the official example with Ridge regression works perfectly. Any suggestions?
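
    For reference, a minimal sketch of the documented custom_root calling convention (simplified and illustrative; not the code from this issue): the optimality function's first argument must be the solution, and the decorated solver's remaining arguments must line up with the optimality function's remaining arguments.

    import jax
    import jax.numpy as jnp
    from jaxopt import implicit_diff

    def inner_obj(x, w):
        # Inner objective whose argmin over x is simply w.
        return jnp.sum((x - w) ** 2)

    @implicit_diff.custom_root(jax.grad(inner_obj))
    def inner_solver(init_x, w):
        # First argument is the initialization; the remaining arguments match inner_obj.
        del init_x
        return w  # closed-form argmin

    grad_wrt_w = jax.grad(lambda w: jnp.sum(inner_solver(jnp.zeros(1), w)))(jnp.array([2.0]))
    print(grad_wrt_w)  # [1.] obtained via implicit differentiation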
    opened by MarioAuditore 0
  • Custom loop pjit example

    Custom loop pjit example

    An MWE of how jax.experimental.pjit can be used in JAXopt (see also PR #346).

    NOTE: jax.experimental.pjit is not yet supported in Colab. However, this example illustrates how users with access to Google Cloud TPUs may use jax.experimental.pjit in combination with JAXopt solvers.

    pull ready 
    opened by fllinares 2
  • Added a new API allowing to warm start the inverse Hessian approximation in LBFGS

    Added a new API allowing to warm start the inverse Hessian approximation in LBFGS

    This fixes #351 .

    @mblondel I couldn't use your suggestion of creating a new type of init, LBFGSInit, because the init_params variable is used for both init_state and update. Therefore I would have had to add case distinctions in the two functions, which seemed unreasonable. Instead, I took the approach I saw in some other iterative solvers, which is to add an extra keyword argument to init_state, update and _value_and_grad_fun.

    I added a test to make sure that this runs, but I am not sure whether we need to add a test to make sure that it improves some cases. I also don't know whether we should test that differentiation is ok.

    opened by zaccharieramzi 5
  • Enable warm-starting the hessian approximation in L-BFGS

    Enable warm-starting the hessian approximation in L-BFGS

    Currently one can only provide an initial estimate of the solution, enabling a warm start of the iterates. But for quasi-Newton methods, it can also be a good idea to provide an initial estimate of the Hessian approximation, typically when solving a similar problem multiple times.

    This was for example done in HOAG by @fabianp (see https://github.com/fabianp/hoag/blob/master/hoag/hoag.py#L109).

    I am willing to implement this in the next few weeks.

    As I know it is of interest to them as well, cc-ing @marius311 and @mblondel

    opened by zaccharieramzi 2
  • Batched QP (and other optimization algorithm)

    Batched QP (and other optimization algorithm)

    I'm trying to make OSQP batchable (so I can make it a layer in neural networks, like OptNet), but I couldn't find any documentation yet about using vmap to solve batched versions of optimization problems.
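
    A rough sketch of the kind of batching being asked about, assuming the OSQP wrapper accepts (Q, c), (A, b), (G, h) tuples like the former QuadraticProgramming interface and that its solution exposes a primal field (untested, purely illustrative):

    import jax
    import jax.numpy as jnp
    from jaxopt import OSQP

    Q = jnp.eye(2)
    A = jnp.array([[1.0, 1.0]])
    b = jnp.array([1.0])
    G = -jnp.eye(2)          # encodes x >= 0
    h = jnp.zeros(2)

    def solve_qp(c):
        sol = OSQP().run(params_obj=(Q, c), params_eq=(A, b), params_ineq=(G, h)).params
        return sol.primal

    batched_c = jnp.stack([jnp.zeros(2), jnp.ones(2), -jnp.ones(2)])
    primal_solutions = jax.vmap(solve_qp)(batched_c)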

    opened by jn-tang 1
Releases(jaxopt-v0.5.5)
  • jaxopt-v0.5.5(Oct 20, 2022)

    New features

    • Added MAML example by Fabian Pedregosa based on initial code by Paul Vicol and Eric Jiang.
    • Added the possibility to stop LBFGS after a line search failure, by Zaccharie Ramzi.
    • Added gamma to LBFGS state, by Zaccharie Ramzi.
    • Added jaxopt.BFGS, by Mathieu Blondel.
    • Added value_and_grad option to all gradient-based solvers, by Mathieu Blondel.
    • Added Fenchel-Young loss, by Quentin Berthet.
    • Added projection_sparse_simplex, by Tianlin Liu.

    Bug fixes and enhancements

    • Fixed missing args, kwargs in resnet example, by Louis Béthune.
    • Corrected the implicit diff examples, by Zaccharie Ramzi.
    • Small optimization in l2-regularized semi-dual OT, by Mathieu Blondel.
    • Numerical stability improvements in jaxopt.LevenbergMarquardt, by Amir Saadat.
    • Dtype consistency in LBFGS, by Alex Botev.

    Deprecations

    • jaxopt.QuadraticProgramming is now fully removed. Use jaxopt.CvxpyQP, jaxopt.OSQP, jaxopt.BoxOSQP and jaxopt.EqualityConstrainedQP instead.

    Contributors

    Alex Botev, Amir Saadat, Fabian Pedregosa, Louis Béthune, Mathieu Blondel, Quentin Berthet, Tianlin Liu, Zaccharie Ramzi.

    Source code(tar.gz)
    Source code(zip)
  • jaxopt-v0.5(Aug 30, 2022)

    New features

    • Added optimal transport related projections: projection_transport, projection_birkhoff, kl_projection_transport, and kl_projection_birkhoff, by Mathieu Blondel (semi-dual formulation) and Tianlin Liu (dual formulation).

    Bug fixes and enhancements

    • Fix LaTeX rendering issue in notebooks, by Amélie Héliou.
    • Avoid gradient recompilations in zoom line search, by Mathieu Blondel.
    • Fix unused Jacobian issue in jaxopt.ScipyRootFinding, by Louis Béthune.
    • Use zoom line search by default in jaxopt.LBFGS and jaxopt.NonlinearCG, by Mathieu Blondel.
    • Pass tolerance argument to jaxopt.ScipyMinimize, by pipme.
    • Handle has_aux in jaxopt.LevenbergMarquardt, by Keunhong Park.
    • Add maxiter keyword argument in jaxopt.ScipyMinimize, by Fabian Pedregosa.

    Contributors

    Louis Béthune, Mathieu Blondel, Amélie Héliou, Keunhong Park, Fabian Pedregosa, pipme.

    Source code(tar.gz)
    Source code(zip)
  • jaxopt-v0.4.3(Jun 28, 2022)

    New features

    • Added zoom line search in jaxopt.LBFGS, by Mathieu Blondel. It can be enabled with the linesearch="zoom" option.

    Bug fixes and enhancements

    • Added support for quadratic polynomial fun in jaxopt.BoxOSQP and jaxopt.OSQP, by Louis Béthune.
    • Added a notebook for the dataset distillation example, by Amélie Héliou.
    • Fixed wrong links and deprecation warnings in notebooks, by Fabian Pedregosa.
    • Changed losses to avoid roundoff, by Jack Valmadre.
    • Fixed init_params bug in multiclass_svm example, by Louis Béthune.

    Contributors

    Louis Béthune, Mathieu Blondel, Amélie Héliou, Fabian Pedregosa, Jack Valmadre.

    Source code(tar.gz)
    Source code(zip)
  • jaxopt-v0.4.2(Jun 10, 2022)

  • jaxopt-v0.4.1(Jun 10, 2022)

    Bug fixes and enhancements

    • Improvements in jaxopt.LBFGS: fixed bug when using use_gamma=True, added stepsize option, strengthened tests, by Mathieu Blondel.
    • Fixed link in resnet notebook, by Fabian Pedregosa.

    Contributors

    Fabian Pedregosa, Mathieu Blondel.

    Source code(tar.gz)
    Source code(zip)
  • jaxopt-v0.4(May 24, 2022)

    New features

    • Added solver jaxopt.LevenbergMarquardt, by Amir Saadat.
    • Added solver jaxopt.BoxCDQP, by Mathieu Blondel.
    • Added projection_hypercube, by Mathieu Blondel.

    Bug fixes and enhancements

    • Fixed solve_normal_cg when the linear operator is “nonsquare” (does not map to a space of same dimension), by Mathieu Blondel.
    • Fixed edge case in jaxopt.Bisection, by Mathieu Blondel.
    • Replaced deprecated tree_multimap with tree_map, by Fan Yang.
    • Added support for leaf cond pytrees in tree_where, by Felipe Llinares.
    • Added Python 3.10 support officially, by Jeppe Klitgaard.
    • In scipy wrappers, converted pytree leaves to jax arrays to determine their shape/dtype, by Roy Frostig.
    • Converted the “Resnet” and “Adversarial Training” examples to notebooks, by Fabian Pedregosa.

    Contributors

    Amir Saadat, Fabian Pedregosa, Fan Yang, Felipe Llinares, Jeppe Klitgaard, Mathieu Blondel, Roy Frostig.

    Source code(tar.gz)
    Source code(zip)
  • jaxopt-v0.3.1(Feb 28, 2022)

    New features

    • Pjit-based example of data parallel training using Flax, by Felipe Llinares.

    Bug fixes and enhancements

    • Support for GPU and the state-of-the-art adversarial training algorithm (PGD) in the robust_training.py example, by Fabian Pedregosa.
    • Update line search in LBFGS to use jit and unroll from LBFGS, by Ian Williamson.
    • Support dynamic maximum iteration count in iterative solvers, by Roy Frostig.
    • Fix tree_where for singleton pytrees, by Louis Béthune.
    • Remove QuadraticProg in projections and set init_params=None by default in QP solvers, by Louis Béthune.
    • Add missing 'value' attribute in LbfgsState, by Mathieu Blondel.

    Contributors

    Felipe Llinares, Fabian Pedregosa, Ian Williamson, Louis Béthune, Mathieu Blondel, Roy Frostig.

    Source code(tar.gz)
    Source code(zip)
  • jaxopt-v0.3(Jan 31, 2022)

    New features

    • jaxopt.LBFGS
    • jaxopt.BacktrackingLineSearch
    • jaxopt.GaussNewton
    • jaxopt.NonlinearCG

    Bug fixes and enhancements

    • Support implicit AD in higher-order differentiation.

    Contributors

    Amir Saadat, Fabian Pedregosa, Geoffrey Négiar, Hyunsung Lee, Mathieu Blondel, Roy Frostig.

    Source code(tar.gz)
    Source code(zip)
  • jaxopt-v0.2(Dec 18, 2021)

    New features

    • Quadratic programming solvers jaxopt.CvxpyQP, jaxopt.OSQP, jaxopt.BoxOSQP and jaxopt.EqualityConstrainedQP
    • Iterative refinement

    New examples

    • Resnet example with Flax and JAXopt.

    Bug fixes and enhancements

    • Prevent recompilation of loops in solver.run if executing without jit.
    • Prevents recomputation of gradient in OptaxSolver.
    • Make solver.update jittable and ensure output states are consistent.
    • Allow Callable for the stepsize argument in jaxopt.ProximalGradient, jaxopt.ProjectedGradient and jaxopt.GradientDescent.

    Deprecated features

    • jaxopt.QuadraticProgramming is deprecated and will be removed in v0.3. Use jaxopt.CvxpyQP, jaxopt.OSQP, jaxopt.BoxOSQP and jaxopt.EqualityConstrainedQP instead.

    Contributors

    Fabian Pedregosa, Felipe Llinares, Geoffrey Negiar, Louis Bethune, Mathieu Blondel, Vikas Sindhwani.

    Source code(tar.gz)
    Source code(zip)
  • jaxopt-v0.1.1(Oct 19, 2021)

    New features

    • Added solver jaxopt.ArmijoSGD
    • Added example Deep Equilibrium (DEQ) model in Flax with Anderson acceleration.
    • Added example Comparison of different SGD algorithms.

    Bug fixes

    • Allow non-jittable proximity operators in jaxopt.ProximalGradient
    • Raise an exception if a quadratic program is infeasible or unbounded

    Contributors

    Fabian Pedregosa, Louis Bethune, Mathieu Blondel.

    Source code(tar.gz)
    Source code(zip)
  • jaxopt-v0.1(Oct 14, 2021)

    Classes

    • jaxopt.AndersonAcceleration
    • jaxopt.AndersonWrapper
    • jaxopt.Bisection
    • jaxopt.BlockCoordinateDescent
    • jaxopt.FixedPointIteration
    • jaxopt.GradientDescent
    • jaxopt.MirrorDescent
    • jaxopt.OptaxSolver
    • jaxopt.PolyakSGD
    • jaxopt.ProjectedGradient
    • jaxopt.ProximalGradient
    • jaxopt.QuadraticProgramming
    • jaxopt.ScipyBoundedLeastSquares
    • jaxopt.ScipyBoundedMinimize
    • jaxopt.ScipyLeastSquares
    • jaxopt.ScipyMinimize
    • jaxopt.ScipyRootFinding
    • Implicit differentiation

    Examples

    • Binary kernel SVM with intercept.
    • Image classification example with Flax and JAXopt.
    • Image classification example with Haiku and JAXopt.
    • VAE example with Haiku and JAXopt.
    • Implicit differentiation of lasso.
    • Multiclass linear SVM (without intercept).
    • Non-negative matrix factorization (NMF) using alternating minimization.
    • Dataset distillation.
    • Implicit differentiation of ridge regression.
    • Robust training.
    • Anderson acceleration of gradient descent.
    • Anderson acceleration of block coordinate descent.
    • Anderson acceleration in application to the Picard–Lindelöf theorem.

    Contributors

    Fabian Pedregosa, Felipe Llinares, Robert Gower, Louis Bethune, Marco Cuturi, Mathieu Blondel, Peter Hawkins, Quentin Berthet, Roy Frostig, Ta-Chu Kao

    Source code(tar.gz)
    Source code(zip)
Owner
Google
Google ❤️ Open Source
SOTA easy to use PyTorch-based DL training library

Easily train or fine-tune SOTA computer vision models from one training repository. SuperGradients Introduction Welcome to SuperGradients, a free open

619 Jan 03, 2023
Multispectral Object Detection with Yolov5

Multispectral-Object-Detection Intro Official Code for Cross-Modality Fusion Transformer for Multispectral Object Detection. Multispectral Object Dete

Richard Fang 121 Jan 01, 2023
FB-tCNN for SSVEP Recognition

FB-tCNN for SSVEP Recognition Here are the codes of the tCNN and FB-tCNN in the paper "Filter Bank Convolutional Neural Network for Short Time-Window

Wenlong Ding 12 Dec 14, 2022
StyleSpace Analysis: Disentangled Controls for StyleGAN Image Generation

StyleSpace Analysis: Disentangled Controls for StyleGAN Image Generation Demo video: CVPR 2021 Oral: Single Channel Manipulation: Localized or attribu

Zongze Wu 267 Dec 30, 2022
A Dou Dizhu (Chinese card game) AI built on an isolating-language assumption and breadth-first search, using a multi-channel stacked-attention Transformer architecture

ddz-ai Introduction: Dou Dizhu is a card game played by at least 3 players with a 54-card deck (jokers included); one side is the "landlord" and the other two players form the opposing side, and whichever side plays out all of its cards first wins. Built on an isolating-language assumption and breadth-first search, ddz-ai constructs a multi-channel stacked-attention Transformer system which, after extensive training, can, in actual games,

freefuiiismyname 88 May 15, 2022
A PyTorch-based library for fast prototyping and sharing of deep neural network models.

A PyTorch-based library for fast prototyping and sharing of deep neural network models.

78 Jan 03, 2023
PyTorch implementation of the paper Ultra Fast Structure-aware Deep Lane Detection

PyTorch implementation of the paper Ultra Fast Structure-aware Deep Lane Detection

1.4k Jan 06, 2023
PyTorch code accompanying the paper "Landmark-Guided Subgoal Generation in Hierarchical Reinforcement Learning" (NeurIPS 2021).

HIGL This is a PyTorch implementation for our paper: Landmark-Guided Subgoal Generation in Hierarchical Reinforcement Learning (NeurIPS 2021). Our cod

Junsu Kim 20 Dec 14, 2022
Unsupervised Image Generation with Infinite Generative Adversarial Networks

Unsupervised Image Generation with Infinite Generative Adversarial Networks Here is the implementation of MICGANs using DCGAN architecture on MNIST da

16 Dec 24, 2021
Official Implementation of CoSMo: Content-Style Modulation for Image Retrieval with Text Feedback

CoSMo.pytorch Official Implementation of CoSMo: Content-Style Modulation for Image Retrieval with Text Feedback, Seungmin Lee*, Dongwan Kim*, Bohyung

Seung Min Lee 54 Dec 08, 2022
Pytorch library for seismic data augmentation

Pytorch library for seismic data augmentation

Artemii Novoselov 27 Nov 22, 2022
This repository is maintained for the scientific paper titled "Study of keyword extraction techniques for Electric Double Layer Capacitor domain using text similarity indexes: An experimental analysis"

kwd-extraction-study This repository is maintained for the scientific paper titled "Study of keyword extraction techniques for Electric Double Layer

ping 543f 1 Dec 05, 2022
Allele-specific pipeline for unbiased read mapping(WIP), QTL discovery(WIP), and allelic-imbalance analysis

WASP2 (Currently in pre-development): Allele-specific pipeline for unbiased read mapping(WIP), QTL discovery(WIP), and allelic-imbalance analysis Requ

McVicker Lab 2 Aug 11, 2022
code for paper -- "Seamless Satellite-image Synthesis"

Seamless Satellite-image Synthesis by Jialin Zhu and Tom Kelly. Project site. The code of our models borrows heavily from the BicycleGAN repository an

Light 14 Apr 05, 2022
Language models are open knowledge graphs ( non official implementation )

language-models-are-knowledge-graphs-pytorch Language models are open knowledge graphs ( work in progress ) A non official reimplementation of Languag

theblackcat102 132 Dec 18, 2022
GANTheftAuto is a fork of the Nvidia's GameGAN

Description GANTheftAuto is a fork of the Nvidia's GameGAN, which is research focused on emulating dynamic game environments. The early research done

Harrison 801 Dec 27, 2022
Code for "ShineOn: Illuminating Design Choices for Practical Video-based Virtual Clothing Try-on", accepted at WACV 2021 Generation of Human Behavior Workshop.

ShineOn: Illuminating Design Choices for Practical Video-based Virtual Clothing Try-on [ Paper ] [ Project Page ] This repository contains the code fo

Andrew Jong 97 Dec 13, 2022
This is an official implementation for "SimMIM: A Simple Framework for Masked Image Modeling".

SimMIM By Zhenda Xie*, Zheng Zhang*, Yue Cao*, Yutong Lin, Jianmin Bao, Zhuliang Yao, Qi Dai and Han Hu*. This repo is the official implementation of

Microsoft 674 Dec 26, 2022
ManimML is a project focused on providing animations and visualizations of common machine learning concepts with the Manim Community Library.

ManimML ManimML is a project focused on providing animations and visualizations of common machine learning concepts with the Manim Community Library.

259 Jan 04, 2023
Chunkmogrify: Real image inversion via Segments

Chunkmogrify: Real image inversion via Segments Teaser video with live editing sessions can be found here This code demonstrates the ideas discussed i

David Futschik 112 Jan 04, 2023