Plug-n-Play Reinforcement Learning in Python with OpenAI Gym and JAX

Overview


coax

Plug-n-Play Reinforcement Learning in Python with OpenAI Gym and JAX


For the full documentation, including many examples, go to https://coax.readthedocs.io/

Install

coax is built on top of JAX, but it doesn't have an explicit dependency on the jax python package. The reason is that your version of jaxlib will depend on your CUDA version. To install without CUDA, simply run:

$ pip install jaxlib jax coax --upgrade

If you do require CUDA support, please check out the Installation Guide.
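Once installed, a quick way to check which backend your jaxlib build is using (a minimal sketch; the exact device names depend on your jaxlib version):

import jax

# Lists the devices JAX can see: CPU devices for a CPU-only install,
# GPU devices when a CUDA-enabled jaxlib is installed.
print(jax.devices())
print(jax.default_backend())  # 'cpu', 'gpu' or 'tpu'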

Getting Started

Have a look at the Getting Started page to train your first RL agent.
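To give a feel for the API, here is a rough sketch of the kind of training loop used throughout the docs and in the issues below (CartPole-v0, the network sizes, and the hyperparameters are placeholders, not the official example):

import gym
import jax
import coax
import haiku as hk
import jax.numpy as jnp
import optax

env = gym.make('CartPole-v0')   # placeholder environment
env = coax.wrappers.TrainMonitor(env)

def func_pi(S, is_training):
    logits = hk.Sequential((hk.Linear(8), jax.nn.relu, hk.Linear(env.action_space.n)))
    return {'logits': logits(S)}

def func_q(S, A, is_training):
    value = hk.Sequential((hk.Linear(8), jax.nn.relu, hk.Linear(1, w_init=jnp.zeros), jnp.ravel))
    X = jnp.concatenate((S, A), axis=-1)   # coax feeds a (preprocessed) action alongside the state
    return value(X)

pi = coax.Policy(func_pi, env)
q = coax.Q(func_q, env)

qlearning = coax.td_learning.QLearning(q, pi_targ=pi, optimizer=optax.adam(0.001))
tracer = coax.reward_tracing.NStep(n=1, gamma=0.9)

for ep in range(250):
    s = env.reset()
    done = False
    while not done:
        a = pi(s)
        s_next, r, done, info = env.step(a)    # old gym step signature, as used in the issues below
        tracer.add(s, a, r, done)
        while tracer:
            metrics = qlearning.update(tracer.pop())
            env.record_metrics(metrics)
        s = s_next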


Comments
  • Quantile Q-Learning Implementation

    Quantile Q-Learning Implementation

    This PR adds a QuantileQ class with function types 3 and 4 that accept a number of quantiles together with the state (and action), as well as a QuantileQLearning class. The QuantileQ function could be merged into the Q class, which would simplify the user-facing API. However, some more work needs to be done to incorporate the QuantileQLearning class into the QLearning class. I just wanted to validate that this is the correct approach for implementing the IQN.

    Some documentation for the quantile Huber loss is still missing, and the notebooks still need to be added and tuned.
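
    For reference, the standard quantile Huber loss from QR-DQN/IQN that this documentation would cover looks roughly like the following sketch (illustrative only, not the PR's implementation):

    import jax.numpy as jnp

    def quantile_huber_loss(td_errors, taus, kappa=1.0):
        # |tau - 1{u < 0}| * Huber_kappa(u) / kappa, averaged over quantiles
        u = td_errors
        abs_u = jnp.abs(u)
        huber = jnp.where(abs_u <= kappa, 0.5 * u ** 2, kappa * (abs_u - 0.5 * kappa))
        return jnp.mean(jnp.abs(taus - (u < 0)) * huber / kappa)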

    Closes https://github.com/coax-dev/coax/issues/3

    opened by frederikschubert 11
  • Add DeepMind Control Suite Example

    Add DeepMind Control Suite Example

    This PR is a rework of https://github.com/coax-dev/coax/pull/26 and adds an example for using SAC on the Walker.walk task from the DeepMind Control Suite.

    Depends on https://github.com/coax-dev/coax/pull/27 and https://github.com/coax-dev/coax/pull/28

    opened by frederikschubert 6
  • Assertion assert_equal_shape failed for MultiDiscrete action space

    Assertion assert_equal_shape failed for MultiDiscrete action space

    First of all, thank you for developing this package; I really like the modular design. I am a bit new to RL and the JAX ecosystem, so my question may be a bit naive. I am currently doing a baseline study with my customized gym environment and VanillaPG, but I ran into the bug shown below and could not figure it out. My understanding is that it is complaining that the shape of log_pi should not be (4,). But I do have a MultiDiscrete action space, and its corresponding log_pi should be something like (4,) or (1, 4). I also attached the output of coax.Policy.example_data(env) and my policy function definition below, in case that helps explain the situation.

    So my questions are:

    1. Do you think this error is related to the fact that I have a MultiDiscrete action space?
    2. Did I declare my policy function properly?
    3. Any general ideas on how to debug JAX functions?

    I would appreciate any feedback. Thank you!

    Error message

    ---------------------------------------------------------------------------
    AssertionError                            Traceback (most recent call last)
    Input In [25], in <cell line: 5>()
         13     transition_batch = tracer.pop()
         14     Gn = transition_batch.Rn
    ---> 15     metrics = vanilla_pg.update(transition_batch, Adv=Gn)
         16     env.record_metrics(metrics)
         17 if done:
    
    File ~/opt/python3.9/site-packages/coax/policy_objectives/_base.py:149, in PolicyObjective.update(self, transition_batch, Adv)
        127 def update(self, transition_batch, Adv):
        128     r"""
        129 
        130     Update the model parameters (weights) of the underlying function approximator.
       (...)
        147 
        148     """
    --> 149     grads, function_state, metrics = self.grads_and_metrics(transition_batch, Adv)
        150     if any(jnp.any(jnp.isnan(g)) for g in jax.tree_leaves(grads)):
        151         raise RuntimeError(f"found nan's in grads: {grads}")
    
    File ~/opt/python3.9/site-packages/coax/policy_objectives/_base.py:218, in PolicyObjective.grads_and_metrics(self, transition_batch, Adv)
        212 if self.REQUIRES_PROPENSITIES and jnp.all(transition_batch.logP == 0):
        213     warnings.warn(
        214         f"In order for {self.__class__.__name__} to work properly, transition_batch.logP "
        215         "should be non-zero. Please sample actions with their propensities: "
        216         "a, logp = pi(s, return_logp=True) and then add logp to your reward tracer, "
        217         "e.g. nstep_tracer.add(s, a, r, done, logp)")
    --> 218 return self._grad_and_metrics_func(
        219     self._pi.params, self._pi.function_state, self.hyperparams, self._pi.rng,
        220     transition_batch, Adv)
    
    File ~/opt/python3.9/site-packages/coax/utils/_jit.py:59, in JittedFunc.__call__(self, *args, **kwargs)
         58 def __call__(self, *args, **kwargs):
    ---> 59     return self._jitted_func(*args, **kwargs)
    
        [... skipping hidden 14 frame]
    
    File ~/opt/python3.9/site-packages/coax/policy_objectives/_base.py:80, in PolicyObjective.__init__.<locals>.grads_and_metrics_func(params, state, hyperparams, rng, transition_batch, Adv)
         77 def grads_and_metrics_func(params, state, hyperparams, rng, transition_batch, Adv):
         78     grads_func = jax.grad(loss_func, has_aux=True)
         79     grads, (metrics, state_new) = \
    ---> 80         grads_func(params, state, hyperparams, rng, transition_batch, Adv)
         82     # add some diagnostics of the gradients
         83     metrics.update(get_grads_diagnostics(grads, f'{self.__class__.__name__}/grads_'))
    
        [... skipping hidden 10 frame]
    
    File ~/opt/python3.9/site-packages/coax/policy_objectives/_base.py:47, in PolicyObjective.__init__.<locals>.loss_func(params, state, hyperparams, rng, transition_batch, Adv)
         45 def loss_func(params, state, hyperparams, rng, transition_batch, Adv):
         46     objective, (dist_params, log_pi, state_new) = \
    ---> 47         self.objective_func(params, state, hyperparams, rng, transition_batch, Adv)
         49     # flip sign to turn objective into loss
         50     loss = -objective
    
    File ~/opt/python3.9/site-packages/coax/policy_objectives/_vanilla_pg.py:52, in VanillaPG.objective_func(self, params, state, hyperparams, rng, transition_batch, Adv)
         49 W = jnp.clip(transition_batch.W, 0.1, 10.)
         51 # some consistency checks
    ---> 52 chex.assert_equal_shape([W, Adv, log_pi])
         53 chex.assert_rank([W, Adv, log_pi], 1)
         54 objective = W * Adv * log_pi
    
    File ~/opt/python3.9/site-packages/chex/_src/asserts_internal.py:197, in chex_assertion.<locals>._chex_assert_fn(*args, **kwargs)
        195 else:
        196   try:
    --> 197     host_assertion(*args, **kwargs)
        198   except jax.errors.ConcretizationTypeError as exc:
        199     msg = ("Chex assertion detected `ConcretizationTypeError`: it is very "
        200            "likely that it tried to access tensors' values during tracing. "
        201            "Make sure that you defined a jittable version of this Chex "
        202            "assertion.")
    
    File ~/opt/python3.9/site-packages/chex/_src/asserts_internal.py:157, in make_static_assertion.<locals>._static_assert(custom_message, custom_message_format_vars, include_default_message, exception_type, *args, **kwargs)
        154     custom_message = custom_message.format(*custom_message_format_vars)
        155   error_msg = f"{error_msg} [{custom_message}]"
    --> 157 raise exception_type(error_msg)
    
    AssertionError: [Chex] Assertion assert_equal_shape failed: Arrays have different shapes: [(1,), (1,), (4,)].
    

    Example data

    ExampleData(
      inputs=Inputs(
        args=ArgsType2(
          S={
            'features': array(shape=(1, 1000), dtype=float32, min=0.008, median=2.13, max=2.77)},
          is_training=True)
        static_argnums=(
          1))
      output=(
        {
          'logits': array(shape=(1, 10), dtype=float32, min=-2.31, median=0.152, max=0.732)},
        {
          'logits': array(shape=(1, 10), dtype=float32, min=-1.54, median=-0.138, max=0.994)},
        {
          'logits': array(shape=(1, 10), dtype=float32, min=-0.984, median=0.0808, max=1.73)},
        {
          'logits': array(shape=(1, 10), dtype=float32, min=-2.74, median=-0.289, max=1.74)}))
    

    Policy function

    def pi(S, is_training):
        module = CustomizedModule()
        res = tuple([{"logits": item} for item in module(S["features"])])
        return res
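
    As a general (coax-agnostic) way to poke at shape issues like this, one option is to temporarily disable jit so the update runs on concrete arrays and intermediate shapes can be printed, e.g. (hedged sketch reusing the names from the traceback above):

    import jax

    # with jit disabled, the traceback points at concrete values, so the shape of
    # log_pi (and of W and Adv) can be inspected with ordinary prints
    with jax.disable_jit():
        metrics = vanilla_pg.update(transition_batch, Adv=Gn)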
    
    question 
    opened by xiangyuy 5
  • 'linear/w' does not match shape

    'linear/w' does not match shape

    I've been starting to learn about RL and have been trying to get coax up and running, but have run into an issue that I'm not sure how to resolve. I'm doing Q-learning on a custom gym environment, and I can run the following pieces successfully:

    q = coax.Q(func_q, env)
    pi = coax.Policy(func_pi, env)
    
    qlearning = coax.td_learning.QLearning(q, pi_targ=pi, optimizer=optax.adam(0.001))
    cache = coax.reward_tracing.NStep(n=1, gamma=0.9)
    

    Additionally, my setup passes the simple checks of:

    data = coax.Q.example_data(env) # Looks good
    ...
    s = env.observation_space.sample()
    a = env.action_space.sample()
    print(q(s,a)) # 0.0
    ...
    a = pi(s)
    print(a) # [0, 0, 0, 0, 0] as I have a MultiDiscrete action space
    

    However, once I get to actually running the training loop:

    for ep in range(50):
      pi.epsilon = 0.1
      s = env.reset()
    
      for t in range(env.maxGuesses):
        a = pi(s)
        s_next, r, done, info = env.step(a)
    
        # update
        cache.add(s, a, r, done)
    
        while cache:
          transition_batch = cache.pop()
          metrics = qlearning.update(transition_batch)
          env.record_metrics(metrics)
    
        if done:
          break
    
        s = s_next
    
        # early stopping
        if env.avg_G > env.reward_threshold:
          break
    

    I get a bunch of errors with the most human-readable of them saying:

    ValueError: 'linear/w' with retrieved shape (420, 30) does not match shape=[940, 30] dtype=dtype('float32')
    

    By adjusting the parameters of the environment I can change which numbers are mismatched, but I can't get them to match, and either way that seems like the wrong solution, as something more fundamental appears to be the issue.

    For reference, here are my functions for q and pi:

    def func_pi(S, is_training):
      logits = hk.Sequential((
        hk.Linear(30), jax.nn.relu, 
        hk.Linear(30), jax.nn.relu, 
        hk.Linear(30), jax.nn.relu,
        hk.Linear(Wordle.wordLength*len(alphabet), w_init=jnp.zeros) # This many possible actions
      ))
      # First, convert to a vector:
      sVec = state_to_vec(S)
    
      # Now get the output:
      logitVec = logits(sVec)
    
      # Now chunk the output into alphabet-sized pieces (definitionally an integral
      # number of them). There will be Wordle.wordLength chunks of this length
      chunks = jnp.split(logitVec, Wordle.wordLength)
    
      # Now format our output array:
      ret = []
      for chunk in chunks:
        ret.append({'logits': jnp.reshape(chunk,(1,len(alphabet)))})
    
      return tuple(ret)
    
    # and for actual state:
    def func_q(S, A, is_training):
      value = hk.Sequential((
        hk.Linear(30), jax.nn.relu, 
        hk.Linear(30), jax.nn.relu,
        hk.Linear(30), jax.nn.relu,
        hk.Linear(1, w_init=jnp.zeros), jnp.ravel
      ))
    
      sVec = state_to_vec(S)
      aVec = action_to_vec(A)
    
      X = jnp.concatenate((sVec, aVec))
      return value(X)
    

    Note that state_to_vec(S) and action_to_vec(A) just convert from my internal types to jnp.arrays for use with Haiku.
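
    One way to narrow down the 420-vs-940 discrepancy would be to print the shapes that func_q actually receives; shapes are static under JAX tracing, so a plain print inside the function fires once per compilation (a sketch of the same func_q from above with a shape print added):

    def func_q(S, A, is_training):
      value = hk.Sequential((
        hk.Linear(30), jax.nn.relu,
        hk.Linear(30), jax.nn.relu,
        hk.Linear(30), jax.nn.relu,
        hk.Linear(1, w_init=jnp.zeros), jnp.ravel
      ))
      sVec = state_to_vec(S)
      aVec = action_to_vec(A)
      X = jnp.concatenate((sVec, aVec))
      # prints at trace time; compare what coax.Q.example_data/init sees vs. a real update
      print('sVec', sVec.shape, 'aVec', aVec.shape, 'X', X.shape)
      return value(X)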

    I'm quite new to coax/JAX/Haiku so it's entirely possible I've set something up wrong. For completeness here's the full text of the error:

    Traceback (most recent call last):
      File "/home/user/wordle/wordle_ai/wordle-game-rl.py", line 314, in <module>
        metrics = qlearning.update(transition_batch)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 87, in update
        grads, function_state, metrics, td_error = self.grads_and_metrics(transition_batch)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 149, in grads_and_metrics
        return self._grads_and_metrics_func(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/utils/_jit.py", line 59, in __call__
        return self._jitted_func(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/api.py", line 426, in cache_miss
        out_flat = xla.xla_call(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1671, in bind
        return call_bind(self, fun, *args, **params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1683, in call_bind
        outs = top_trace.process_call(primitive, fun, tracers, params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 596, in process_call
        return primitive.impl(f, *tracers, **params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/dispatch.py", line 142, in _xla_call_impl
        compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/linear_util.py", line 272, in memoized_fun
        ans = call(fun, *args)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/dispatch.py", line 169, in _xla_callable_uncached
        return lower_xla_callable(fun, device, backend, name, donated_invars,
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/profiler.py", line 206, in wrapper
        return func(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/dispatch.py", line 197, in lower_xla_callable
        jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/profiler.py", line 206, in wrapper
        return func(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1623, in trace_to_jaxpr_final
        jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1594, in trace_to_subjaxpr_dynamic
        ans = fun.call_wrapped(*in_tracers_)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/linear_util.py", line 166, in call_wrapped
        ans = self.f(*args, **dict(self.params, **kwargs))
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 462, in grads_and_metrics_func
        grads, (td_error, state_new, metrics) = jax.grad(loss_func, has_aux=True)(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/api.py", line 996, in grad_f_aux
        (_, aux), g = value_and_grad_f(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/api.py", line 1067, in value_and_grad_f
        ans, vjp_py, aux = _vjp(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/api.py", line 2478, in _vjp
        out_primal, out_vjp, aux = ad.vjp(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/ad.py", line 118, in vjp
        out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/ad.py", line 103, in linearize
        jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/profiler.py", line 206, in wrapper
        return func(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 520, in trace_to_jaxpr
        jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/linear_util.py", line 166, in call_wrapped
        ans = self.f(*args, **dict(self.params, **kwargs))
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 436, in loss_func
        Q, state_new = self.q.function_type1(params, state, next(rngs), S, A, True)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/utils/_jit.py", line 59, in __call__
        return self._jitted_func(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/api.py", line 426, in cache_miss
        out_flat = xla.xla_call(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1671, in bind
        return call_bind(self, fun, *args, **params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1683, in call_bind
        outs = top_trace.process_call(primitive, fun, tracers, params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/ad.py", line 324, in process_call
        result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1671, in bind
        return call_bind(self, fun, *args, **params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1683, in call_bind
        outs = top_trace.process_call(primitive, fun, tracers, params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 204, in process_call
        jaxpr, out_pvals, consts, env_tracers = self.partial_eval(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 317, in partial_eval
        out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1671, in bind
        return call_bind(self, fun, *args, **params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1683, in call_bind
        outs = top_trace.process_call(primitive, fun, tracers, params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1364, in process_call
        jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1594, in trace_to_subjaxpr_dynamic
        ans = fun.call_wrapped(*in_tracers_)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/linear_util.py", line 166, in call_wrapped
        ans = self.f(*args, **dict(self.params, **kwargs))
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/transform.py", line 383, in apply_fn
        out = f(*args, **kwargs)
      File "/home/user/wordle/wordle_ai/wordle-game-rl.py", line 264, in func_1
        return value(X)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 428, in wrapped
        out = f(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 279, in run_interceptors
        return bound_method(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/basic.py", line 125, in __call__
        out = layer(out, *args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 428, in wrapped
        out = f(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 279, in run_interceptors
        return bound_method(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/basic.py", line 178, in __call__
        w = hk.get_parameter("w", [input_size, output_size], dtype, init=w_init)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/base.py", line 319, in get_parameter
        raise ValueError(
    jax._src.traceback_util.UnfilteredStackTrace: ValueError: 'linear/w' with retrieved shape (420, 30) does not match shape=[940, 30] dtype=dtype('float32')
    
    The stack trace below excludes JAX-internal frames.
    The preceding is the original exception that occurred, unmodified.
    
    --------------------
    
    The above exception was the direct cause of the following exception:
    
    Traceback (most recent call last):
      File "/home/user/wordle/wordle_ai/wordle-game-rl.py", line 314, in <module>
        metrics = qlearning.update(transition_batch)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 87, in update
        grads, function_state, metrics, td_error = self.grads_and_metrics(transition_batch)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 149, in grads_and_metrics
        return self._grads_and_metrics_func(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/utils/_jit.py", line 59, in __call__
        return self._jitted_func(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 462, in grads_and_metrics_func
        grads, (td_error, state_new, metrics) = jax.grad(loss_func, has_aux=True)(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 436, in loss_func
        Q, state_new = self.q.function_type1(params, state, next(rngs), S, A, True)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/utils/_jit.py", line 59, in __call__
        return self._jitted_func(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/transform.py", line 383, in apply_fn
        out = f(*args, **kwargs)
      File "/home/user/wordle/wordle_ai/wordle-game-rl.py", line 264, in func_1
        return value(X)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 428, in wrapped
        out = f(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 279, in run_interceptors
        return bound_method(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/basic.py", line 125, in __call__
        out = layer(out, *args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 428, in wrapped
        out = f(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 279, in run_interceptors
        return bound_method(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/basic.py", line 178, in __call__
        w = hk.get_parameter("w", [input_size, output_size], dtype, init=w_init)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/base.py", line 319, in get_parameter
        raise ValueError(
    ValueError: 'linear/w' with retrieved shape (420, 30) does not match shape=[940, 30] dtype=dtype('float32')
    

    Please let me know if other information would be useful or relevant (or let me know if this isn't actually a coax issue...).

    Thanks for your help and the neat package.

    bug good first issue question 
    opened by bcerjan 5
  • DQN pong example doesn't work off the shelf

    DQN pong example doesn't work off the shelf

    Describe the bug

    Running the DQN example on pong generates the following error when generating a gif:

      File ".../lib/python3.9/site-packages/coax/utils/_misc.py", line 475, in generate_gif
        assert env.render_mode == 'rgb_array', "env.render_mode must be 'rgb_array'"
    

    This is likely due to some recent updates to gym. Currently, on gym==0.26.2 I observe the following:

    import gym
    env = gym.make('PongNoFrameskip-v4', render_mode="rgb_array")
    print(env.render_mode) # prints None
    
    opened by thisiscam 4
  • Add dm_control example for SAC

    Add dm_control example for SAC

    This PR introduces the common squashed normal distribution for the SAC policy on dm_control and provides an example that solves the walker.walk task. Interestingly, clipping the actions to the range [-1, 1] diverges.
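
    For reference, the common tanh-squashed normal trick mentioned here looks roughly like the following sketch (illustrative only, not the PR's actual implementation):

    import jax
    import jax.numpy as jnp

    def sample_squashed_normal(rng, mu, log_sigma):
        # sample u ~ N(mu, sigma), squash with tanh, and apply the standard
        # change-of-variables correction to the log-probability
        sigma = jnp.exp(log_sigma)
        u = mu + sigma * jax.random.normal(rng, mu.shape)
        a = jnp.tanh(u)                                   # action in (-1, 1), no clipping needed
        logp_u = -0.5 * (((u - mu) / sigma) ** 2 + 2 * log_sigma + jnp.log(2 * jnp.pi))
        logp_a = jnp.sum(logp_u - jnp.log(1 - a ** 2 + 1e-6), axis=-1)
        return a, logp_a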

    @KristianHolsheimer How would you go about changing the installation script for this notebook to add dm_control as a dependency?

    opened by frederikschubert 4
  • Frozen Lake example has an invalid gym signature.

    Frozen Lake example has an invalid gym signature.

    Describe the bug

    The example for Frozen Lake in the main branch of the docs isn't fully updated for the new version of gym's signature.

    ValueError                                Traceback (most recent call last)
         77
         78     a = pi.mode(s)
    ---> 79     s, r, done, info = env.step(a)
         80
         81     env.render()

    ValueError: too many values to unpack (expected 4)
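
    For context, gym 0.26+ returns five values from step and an (obs, info) pair from reset, so the loop in the notebook would need something like:

    s, info = env.reset()
    a = pi.mode(s)
    s, r, terminated, truncated, info = env.step(a)
    done = terminated or truncated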

    Expected behavior

    Executing the notebook should not result in a ValueError.

    To Reproduce

    Colab notebook to repro the bug:

    - https://colab.research.google.com/...

    Runtime used for this colab notebook: Any (CPU/GPU/TPU).

    Additional context

    Simple fix, happy to contribute a pull request.

    opened by dbleyl 3
  • Incorporating jax.jit into a custom policy

    Incorporating jax.jit into a custom policy

    I'm a bit new to JAX, so my question might sound very naive. Suppose we are trying to solve a policy optimization problem with the REINFORCE algorithm, and suppose we already have our environment at hand (env). We define our custom policy as follows,

    class CustomPolicy(hk.Module):
        def __init__(self, name = None):
            super().__init__(name = name)
        
    
        def __call__(self, x):
            w = hk.get_parameter("w", shape= ... , dtype = x.dtype, init=jnp.zeros)
            # some computation
            return out
    

    Per the documentation, we then define

    def custom_policy(S, is_training=True):
        logits = CustomPolicy()
        return {'logits': logits(S)}
    

    and finally the policy is stated as follows,

    pi = coax.Policy(custom_policy, env)

    I was wondering whether there is any way to incorporate @jax.jit into this structure to further improve performance. Thanks.
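
    For what it's worth, the generic JAX/Haiku pattern for jitting a module-based function looks roughly like the sketch below (this is not coax-specific, and coax already routes its function approximators through jit-compiled calls, as the coax/utils/_jit.py frames in the tracebacks elsewhere on this page suggest). A plain hk.Linear stands in for CustomPolicy here:

    import jax
    import jax.numpy as jnp
    import haiku as hk

    def custom_policy(S, is_training=True):
        # stand-in for CustomPolicy(); any hk.Module works the same way
        return {'logits': hk.Linear(4)(S)}

    # hk.transform turns the module-using function into a pure (params, rng, S) -> out function
    transformed = hk.transform(lambda S: custom_policy(S))

    S_dummy = jnp.zeros((1, 8))                    # placeholder input shape
    params = transformed.init(jax.random.PRNGKey(0), S_dummy)
    apply_jit = jax.jit(transformed.apply)         # the pure apply function is what gets jitted
    out = apply_jit(params, None, S_dummy)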

    question 
    opened by UweGensheimer 3
  • Multi-Step Entropy Regularization for SAC

    Multi-Step Entropy Regularization for SAC

    • Add a record_extra_info flag to the NStep tracer that records the intermediate states in a new extra_info field on TransitionBatch
    • Add support for the NStepEntropyRegularizer in SoftPG

    This PR contains an initial working implementation of the mechanism and sums up the discounted entropy bonuses of the states s_t, s_{t + 1}, ... , s_{t + n - 1} for the soft policy gradient regularization.
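
    A back-of-the-envelope version of that discounted sum (purely illustrative, not the PR's code) would be:

    import jax.numpy as jnp

    def discounted_entropy_bonus(entropies, gamma):
        # entropies[k] = H(pi(.|s_{t+k})) for k = 0 .. n-1; returns sum_k gamma**k * entropies[k]
        discounts = gamma ** jnp.arange(entropies.shape[0])
        return jnp.sum(discounts * entropies)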

    opened by frederikschubert 3
  • Implementation of SAC

    Implementation of SAC

    Since SAC is really similar to TD3, we are able to re-use most of its components. The differences are:

    • The actions used to update the q-functions and the policy are sampled from the current policy (instead of taking its mode); see the sketch below.
    • There is no target policy.
    • The log variance of the policy depends on the state.
    • The policy is entropy regularized.

    The current implementation does not support multi-step td-learning.
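
    Concretely, the sampling difference in the first bullet amounts to something like this (hedged sketch; pi and pi_targ are hypothetical handles, using the coax policy call signatures that appear elsewhere on this page):

    # TD3-style: deterministic target action
    a_next = pi_targ.mode(s_next)

    # SAC-style: action sampled from the current policy, together with its log-propensity
    a_next, logp = pi(s_next, return_logp=True)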

    opened by frederikschubert 3
  • AttributeError: module 'jax.api' has no attribute '_jit_is_disabled'

    AttributeError: module 'jax.api' has no attribute '_jit_is_disabled'

    Hi, I'm unsure whether this is due to coax or jax, but I get this error when running the pendulum PPO example; the DQN example runs fine, however.

    A similar error I found online recommended changing the version of jaxlib, so I switched to the jaxlib version set out in the coax getting started guide, but that seemed to have no effect. Versions: jax 0.2.13, jaxlib 0.1.65+cuda111, coax 0.1.6.
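
    For anyone trying to reproduce this, pinning the CPU versions mentioned above is just (a CUDA setup additionally needs the matching +cuda111 jaxlib wheel):

    $ pip install jax==0.2.13 jaxlib==0.1.65 coax==0.1.6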

    question 
    opened by mmcaulif 3
  • Recurrent Experience Replay

    Recurrent Experience Replay

    Is your feature request related to a problem? Please describe.

    It seems that the implemented replay buffers only operate over transitions, with no ability to operate over entire sequences. This prevents the use of recurrent policies for tackling POMDPs.

    Describe the solution you'd like

    A SequenceReplayBuffer that returns contiguous episodes instead of shuffled transitions.
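
    A bare-bones sketch of the idea (not an actual coax API, just to illustrate the shape of such a buffer):

    import random
    from collections import deque

    class SequenceReplayBuffer:
        """Stores whole episodes and samples contiguous sub-sequences from them."""

        def __init__(self, capacity):
            self._episodes = deque(maxlen=capacity)   # each entry is a list of transitions

        def add_episode(self, transitions):
            self._episodes.append(list(transitions))

        def sample(self, seq_len):
            ep = random.choice(self._episodes)
            if len(ep) <= seq_len:
                return ep
            start = random.randrange(len(ep) - seq_len + 1)
            return ep[start:start + seq_len]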

    Describe alternatives you've considered

    Additional context

    enhancement 
    opened by smorad 3
  • MiniMax Algorithm?

    MiniMax Algorithm?

    How would you implement a minimax q-learner with coax?

    Hi there! I love the package and how accessible it is to relative newbies. The tutorials are pretty great and the accompanying videos are very helpful!

    I was wondering what the best way to implement a minimax algorithm would be, would you recommend using two policies pi1 and pi2? Or is there something better suited for this?

    I'd like to re-implement something like this old blogpost of mine in coax to get a better feel for the library.

    Any help would be greatly appreciated :)

    question 
    opened by flaport 1
  • Convert Numpy Docstrings to Google Style

    Convert Numpy Docstrings to Google Style

    This issue tracks the progress of converting the numpy style docstrings to the more concise Google style.

    • [ ] _core
    • [ ] experience_replay
    • [ ] model_updaters
    • [ ] policy_objectives
    • [ ] proba_dists
    • [ ] reward_tracing
    • [ ] td_learning
    • [ ] utils
    • [ ] value_transforms
    • [ ] wrappers

    This depends on the type annotations https://github.com/coax-dev/coax/issues/13 for easier automatic conversions.

    enhancement 
    opened by frederikschubert 0
  • Add Type Annotations

    Add Type Annotations

    This issue tracks the progress of adding type annotations to coax.

    • [ ] _core
    • [ ] experience_replay
    • [ ] model_updaters
    • [ ] policy_objectives
    • [ ] proba_dists
    • [ ] reward_tracing
    • [ ] td_learning
    • [ ] utils
    • [ ] value_transforms
    • [ ] wrappers

    The types are added by utilising pyannotate and adding the following snippet to the coax._base.TestCase class:

    ...
        @classmethod
        def setUpClass(cls) -> None:
            collect_types.init_types_collection()
            collect_types.start()
    
        @classmethod
        def tearDownClass(cls) -> None:
            collect_types.stop()
            type_replacements = {
                "jaxlib.xla_extension.DeviceArray": "jax.numpy.ndarray",
                "haiku._src.data_structures.FlatMapping": "typing.Mapping",
                "coax._core.policy_test": "gym.Env"
            }
            types_str = collect_types.dumps_stats()
            for inferred_type, replacement in type_replacements.items():
                types_str = types_str.replace(inferred_type, replacement)
            with open(sys.modules[cls.__module__].__file__.replace(".py", "_types.json"), "w") as f:
                f.write(types_str)
    ...
    

    and the types are added automatically

    for t in coax/**/*_test_types.json
    do
        pyannotate --type-info $t -3 coax/* -w
    done
    
    enhancement 
    opened by frederikschubert 0
  • PPOClip grad update seems to cause inf update

    PPOClip grad update seems to cause inf update

    Describe the bug

    Hey Kris, love your framework! I'm working with a custom environment, and your discrete-action unit test works perfectly locally. I haven't spent much time investigating this yet; I'm just creating this issue in case something jumps out at you as the problem. I plan on continuing to debug it.

    During the first PPOClip update with the custom gym environment, the model weights get changed to +/-inf despite non-infinite grads.

    Expected behavior

    ...
    adv = np.random.rand(32)
    grads, function_state, metrics = ppo_clip.grads_and_metrics(transition_batch, Adv=adv)
    print("grads", grads)
    print(ppo_clip._pi.params)
    metrics_pi = ppo_clip.update(transition_batch, Adv=adv) # This is the problem
    print(ppo_clip._pi.params)
    

    Results in:

    grads FlatMapping({
      'linear': FlatMapping({
                  'b': DeviceArray([ 0.0477 , -0.02505, -0.05048,  0.02798], dtype=float16),
                  'w': DeviceArray([[ 0.01338 , -0.01921 , -0.01038 ,  0.01622 ],
                                    [ 0.02406 , -0.01683 , -0.02039 ,  0.01316 ],
                                    [ 0.0332  , -0.0227  , -0.03108 ,  0.02061 ],
                                    ...,
                                    [ 0.02452 , -0.00956 , -0.01997 ,  0.005024],
                                    [ 0.010025,  0.001724, -0.03467 ,  0.02295 ],
                                    [ 0.01886 , -0.01413 , -0.01494 ,  0.01022 ]], dtype=float16),
                }),
    
    FlatMapping({
      'linear': FlatMapping({
                  'w': DeviceArray([[-1.0124e-02,  3.4389e-03,  2.9316e-03,  6.5498e-03],
                                    [ 3.3302e-03, -1.7233e-03, -3.0422e-03, -1.8060e-04],
                                    [-2.8908e-05, -3.3131e-03, -6.1073e-03,  6.5804e-03],
                                    ...,
                                    [-2.5597e-03,  7.3471e-03, -3.6221e-03, -5.6801e-03],
                                    [-7.3471e-03, -3.7746e-03,  5.8746e-03,  6.1531e-03],
                                    [-1.1940e-03,  6.9733e-03, -5.0507e-03,  3.4218e-03]],            dtype=float16),
                  'b': DeviceArray([0., 0., 0., 0.], dtype=float16),
                }),
    })
    
    FlatMapping({
      'linear': FlatMapping({
                  'b': DeviceArray([-0.001002,  0.000978,  0.001001, -0.001007], dtype=float16),
                  'w': DeviceArray([[-0.01111  ,  0.004448 ,  0.00386  ,  0.00551  ],
                                    [ 0.002354 , -0.0007563, -0.002048 , -0.001162 ],
                                    [-0.001021 , -0.002335 , -0.005104 ,  0.005558 ],
                                    ...,
                                    [-0.003561 ,  0.008224 , -0.002628 ,       -inf],
                                    [-0.00828  ,       -inf,  0.006874 ,  0.00515  ],
                                    [-0.002203 ,  0.00804  , -0.004086 ,  0.002493 ]],            dtype=float16),
                }),
    

    Here is the full repro script, taken from the Pong PPO example and slightly modified; it won't run as-is because of the custom environment. This is a dummy example, not the actual policy and value networks that would be used:

    import os
    from luxai2021.env.lux_env import LuxEnvironment, LuxEnvironmentTeam
    from luxai2021.game.game import Game
    from luxai2021.game.actions import *
    from luxai2021.game.constants import LuxMatchConfigs_Default
    
    from luxai2021.env.agent import Agent, AgentWithTeamModel
    import numpy as np
    
    from agent import TeamAgent
    
    # set some env vars
    os.environ.setdefault('JAX_PLATFORM_NAME', 'cpu')     # tell JAX to use the CPU
    os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.1'  # don't use all gpu mem
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'              # tell XLA to be quiet
    
    import gym
    import jax
    import coax
    import haiku as hk
    import jax.numpy as jnp
    from optax import adam
    
    
    # the name of this script
    name = 'ppo'
    
    configs = LuxMatchConfigs_Default
    
    player = TeamAgent(mode="train")
    opponent = Agent()
    
    env = LuxEnvironment(configs=configs,
                                    learning_agent=player,
                                    opponent_agent=opponent)
    env = coax.wrappers.TrainMonitor(env, name=name, tensorboard_dir=f"./data/tensorboard/{name}")
    
    def func_pi(S, is_training):
        n_actions = 4
        out = {'logits': hk.Linear(n_actions)(hk.Flatten()(S)) }
        return out
    
    def func_v(S, is_training):
        h = jnp.ravel(hk.Linear(1)(hk.Flatten()(S)))
        return h
    
    '''
    def func_pi(S, is_training):
        #print(env.action_space.shape)
        n_filters = 5
        n_actions = 4
        n_layers = 3
    
        h = hk.Conv2D(n_filters, kernel_shape=3, stride=1, padding="SAME", data_format='NCHW')(S)
        for layer in range(n_layers):
            h = jax.nn.relu(h + hk.Conv2D(n_filters, kernel_shape=3, stride=1, padding="SAME", data_format='NCHW')(h))
        
        print('h', type(h), h.shape)
        h_head = (h * S[:,:1]).reshape(h.shape[0], h.shape[1], -1).sum(-1) # torch.Size([1, N_LAYERS])
        h_head_actions = hk.Linear(n_actions)(h_head)
        print('h_head_actions', type(h_head_actions), h_head_actions.shape)
        #print(h_head_actions)
    
        out = {'logits': h_head_actions}
        
        return out
    
    def func_v(S, is_training):
        n_filters = 5
        n_layers = 3
    
        h = hk.Conv2D(n_filters, kernel_shape=3, stride=1, padding="SAME", data_format='NCHW')(S)
        for layer in range(n_layers):
            h = jax.nn.relu(hk.Conv2D(n_filters, kernel_shape=3, stride=2, data_format='NCHW')(h))
    
        h = hk.Flatten()(h)
        h = jax.nn.relu(hk.Linear(64)(h))
        h = jnp.ravel(hk.Linear(1, w_init=jnp.zeros)(h))
        
        return h
    '''
    
    
    # function approximators
    pi = coax.Policy(func_pi, env)
    v = coax.V(func_v, env)
    
    # target networks
    pi_behavior = pi.copy()
    v_targ = v.copy()
    
    # policy regularizer (avoid premature exploitation)
    entropy = coax.regularizers.EntropyRegularizer(pi, beta=0.001)
    
    # updaters
    simpletd = coax.td_learning.SimpleTD(v, v_targ, optimizer=adam(3e-4))
    ppo_clip = coax.policy_objectives.PPOClip(pi, regularizer=entropy, optimizer=adam(3e-4))
    
    # reward tracer and replay buffer
    tracer = coax.reward_tracing.NStep(n=5, gamma=0.99)
    buffer = coax.experience_replay.SimpleReplayBuffer(capacity=256)
    
    # run episodes
    max_episode_steps = 400
    while env.T < 3000000:
        s = env.reset()
    
        for t in range(max_episode_steps):
            print(t)
            a, logp = pi_behavior(s, return_logp=True)
            s_next, r, done, info = env.step(a)
    
            # trace rewards and add transition to replay buffer
            tracer.add(s, a, r, done, logp)
            while tracer:
                buffer.add(tracer.pop())
    
            # learn
            if len(buffer) >= buffer.capacity:
                num_batches = int(4 * buffer.capacity / 32)  # 4 epochs per round
                for i in range(num_batches):
                    transition_batch = buffer.sample(32)
                    grads, function_state, metrics, td_error = simpletd.grads_and_metrics(transition_batch)
                    metrics_v, td_error = simpletd.update(transition_batch, return_td_error=True)
    
                    
                    adv = np.random.rand(32)
                    grads, function_state, metrics = ppo_clip.grads_and_metrics(transition_batch, Adv=adv)
                    print("grads", grads)
                    print(ppo_clip._pi.params)
                    metrics_pi = ppo_clip.update(transition_batch, Adv=adv) # This is the problem
                    print(ppo_clip._pi.params)
                    exit()
                    env.record_metrics(metrics_pi)
                    env.record_metrics(metrics_v)
                    
    
                buffer.clear()
    
                # sync target networks
                pi_behavior.soft_update(pi, tau=0.1)
                v_targ.soft_update(v, tau=0.1)
    
            if done:
                break
    
            s = s_next
    
        # generate an animated GIF to see what's going on
        if env.period(name='generate_gif', T_period=10000) and env.T > 50000:
            T = env.T - env.T % 10000  # round to 10000s
            coax.utils.generate_gif(
                env=env, policy=pi, resize_to=(320, 420),
                filepath=f"./data/gifs/{name}/T{T:08d}.gif")
    
    
    opened by glmcdona 3
Releases (v0.1.12)