Plug-n-Play Reinforcement Learning in Python with OpenAI Gym and JAX

Overview

coax

For the full documentation, including many examples, go to https://coax.readthedocs.io/

Install

coax is built on top of JAX, but it doesn't have an explicit dependence on the jax python package. The reason is that your version of jaxlib will depend on your CUDA version. To install without CUDA, simply run:

$ pip install jaxlib jax coax --upgrade

If you do require CUDA support, please check out the Installation Guide.

Getting Started

Have a look at the Getting Started page to train your first RL agent.


Comments
  • Quantile Q-Learning Implementation

    This PR adds a QuantileQ class with function types 3 and 4 that accept a number of quantiles together with the state (and action), as well as a QuantileQLearning class. The QuantileQ function could be merged into the Q class which would simplify the user-facing API. However, some more work needs to be done to incorporate the QuantileQLearning class into the QLearning class. I just wanted to validate that this is the correct approach to take to implement the IQN.

    Some documentation for the quantile Huber loss is still missing, and the notebooks still need to be added and tuned.
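For reference, the quantile Huber loss as it is usually defined for quantile regression in distributional RL can be sketched in plain Python. This is an illustrative sketch only, not the code from this PR; the function names `huber` and `quantile_huber_loss` and the default `kappa=1.0` are assumptions.

```python
def huber(u, kappa=1.0):
    """Huber loss: quadratic for |u| <= kappa, linear beyond that."""
    if abs(u) <= kappa:
        return 0.5 * u * u
    return kappa * (abs(u) - 0.5 * kappa)


def quantile_huber_loss(td_error, tau, kappa=1.0):
    """Quantile Huber loss for one TD error and one quantile fraction tau.

    td_error = target - prediction. Positive errors are weighted by tau,
    negative errors by (1 - tau), which makes the loss asymmetric.
    """
    indicator = 1.0 if td_error < 0 else 0.0
    return abs(tau - indicator) * huber(td_error, kappa) / kappa


if __name__ == "__main__":
    for u in (-2.0, -0.5, 0.5, 2.0):
        print(u, quantile_huber_loss(u, tau=0.25))
```

In a full implementation this quantity would be averaged over pairs of predicted and target quantiles; that part is left out here.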

    Closes https://github.com/coax-dev/coax/issues/3

    opened by frederikschubert 11
  • Add DeepMind Control Suite Example

    This PR is a rework of https://github.com/coax-dev/coax/pull/26 and adds an example for using SAC on the Walker.walk task from the DeepMind Control Suite.

    Depends on https://github.com/coax-dev/coax/pull/27 and https://github.com/coax-dev/coax/pull/28

    opened by frederikschubert 6
  • Assertion assert_equal_shape failed for MultiDiscrete action space

    First of all, thank you for developing this package; I really like the modular design. I am a bit new to RL and the JAX ecosystem, so my question may be a bit naive. I am currently doing a baseline study with my customized gym environment and VanillaPG, but I encountered the error shown below and could not figure it out. My understanding is that it is complaining that the shape of log_pi should not be (4,). But I do have a MultiDiscrete action space, and its corresponding log_pi should be something like (4,) or (1, 4). I have also attached the output of coax.Policy.example_data(env) and my policy function definition below, in case that helps explain the situation.

    So my questions are:

    1. Do you think this error is related to the fact that I have a MultiDiscrete action space?
    2. Did I declare my policy function properly?
    3. Any general ideas on how to debug JAX functions?

    I would appreciate any feedback. Thank you!

    Error message

    ---------------------------------------------------------------------------
    AssertionError                            Traceback (most recent call last)
    Input In [25], in <cell line: 5>()
         13     transition_batch = tracer.pop()
         14     Gn = transition_batch.Rn
    ---> 15     metrics = vanilla_pg.update(transition_batch, Adv=Gn)
         16     env.record_metrics(metrics)
         17 if done:
    
    File ~/opt/python3.9/site-packages/coax/policy_objectives/_base.py:149, in PolicyObjective.update(self, transition_batch, Adv)
        127 def update(self, transition_batch, Adv):
        128     r"""
        129 
        130     Update the model parameters (weights) of the underlying function approximator.
       (...)
        147 
        148     """
    --> 149     grads, function_state, metrics = self.grads_and_metrics(transition_batch, Adv)
        150     if any(jnp.any(jnp.isnan(g)) for g in jax.tree_leaves(grads)):
        151         raise RuntimeError(f"found nan's in grads: {grads}")
    
    File ~/opt/python3.9/site-packages/coax/policy_objectives/_base.py:218, in PolicyObjective.grads_and_metrics(self, transition_batch, Adv)
        212 if self.REQUIRES_PROPENSITIES and jnp.all(transition_batch.logP == 0):
        213     warnings.warn(
        214         f"In order for {self.__class__.__name__} to work properly, transition_batch.logP "
        215         "should be non-zero. Please sample actions with their propensities: "
        216         "a, logp = pi(s, return_logp=True) and then add logp to your reward tracer, "
        217         "e.g. nstep_tracer.add(s, a, r, done, logp)")
    --> 218 return self._grad_and_metrics_func(
        219     self._pi.params, self._pi.function_state, self.hyperparams, self._pi.rng,
        220     transition_batch, Adv)
    
    File ~/opt/python3.9/site-packages/coax/utils/_jit.py:59, in JittedFunc.__call__(self, *args, **kwargs)
         58 def __call__(self, *args, **kwargs):
    ---> 59     return self._jitted_func(*args, **kwargs)
    
        [... skipping hidden 14 frame]
    
    File ~/opt/python3.9/site-packages/coax/policy_objectives/_base.py:80, in PolicyObjective.__init__.<locals>.grads_and_metrics_func(params, state, hyperparams, rng, transition_batch, Adv)
         77 def grads_and_metrics_func(params, state, hyperparams, rng, transition_batch, Adv):
         78     grads_func = jax.grad(loss_func, has_aux=True)
         79     grads, (metrics, state_new) = \
    ---> 80         grads_func(params, state, hyperparams, rng, transition_batch, Adv)
         82     # add some diagnostics of the gradients
         83     metrics.update(get_grads_diagnostics(grads, f'{self.__class__.__name__}/grads_'))
    
        [... skipping hidden 10 frame]
    
    File ~/opt/python3.9/site-packages/coax/policy_objectives/_base.py:47, in PolicyObjective.__init__.<locals>.loss_func(params, state, hyperparams, rng, transition_batch, Adv)
         45 def loss_func(params, state, hyperparams, rng, transition_batch, Adv):
         46     objective, (dist_params, log_pi, state_new) = \
    ---> 47         self.objective_func(params, state, hyperparams, rng, transition_batch, Adv)
         49     # flip sign to turn objective into loss
         50     loss = -objective
    
    File ~/opt/python3.9/site-packages/coax/policy_objectives/_vanilla_pg.py:52, in VanillaPG.objective_func(self, params, state, hyperparams, rng, transition_batch, Adv)
         49 W = jnp.clip(transition_batch.W, 0.1, 10.)
         51 # some consistency checks
    ---> 52 chex.assert_equal_shape([W, Adv, log_pi])
         53 chex.assert_rank([W, Adv, log_pi], 1)
         54 objective = W * Adv * log_pi
    
    File ~/opt/python3.9/site-packages/chex/_src/asserts_internal.py:197, in chex_assertion.<locals>._chex_assert_fn(*args, **kwargs)
        195 else:
        196   try:
    --> 197     host_assertion(*args, **kwargs)
        198   except jax.errors.ConcretizationTypeError as exc:
        199     msg = ("Chex assertion detected `ConcretizationTypeError`: it is very "
        200            "likely that it tried to access tensors' values during tracing. "
        201            "Make sure that you defined a jittable version of this Chex "
        202            "assertion.")
    
    File ~/opt/python3.9/site-packages/chex/_src/asserts_internal.py:157, in make_static_assertion.<locals>._static_assert(custom_message, custom_message_format_vars, include_default_message, exception_type, *args, **kwargs)
        154     custom_message = custom_message.format(*custom_message_format_vars)
        155   error_msg = f"{error_msg} [{custom_message}]"
    --> 157 raise exception_type(error_msg)
    
    AssertionError: [Chex] Assertion assert_equal_shape failed: Arrays have different shapes: [(1,), (1,), (4,)].
    

    Example data

    ExampleData(
      inputs=Inputs(
        args=ArgsType2(
          S={
            'features': array(shape=(1, 1000), dtype=float32, min=0.008, median=2.13, max=2.77)
          is_training=True)
        static_argnums=(
          1))
      output=(
        {
          'logits': array(shape=(1, 10), dtype=float32, min=-2.31, median=0.152, max=0.732)},
        {
          'logits': array(shape=(1, 10), dtype=float32, min=-1.54, median=-0.138, max=0.994)},
        {
          'logits': array(shape=(1, 10), dtype=float32, min=-0.984, median=0.0808, max=1.73)},
        {
          'logits': array(shape=(1, 10), dtype=float32, min=-2.74, median=-0.289, max=1.74)}))
    

    Policy function

    def pi(S, is_training):
        module = CustomizedModule()
        res = tuple([{"logits": item} for item in module(S["features"])])
        return res
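One way to read the shapes in the assertion: W and Adv carry one value per transition, shape (1,), whereas log_pi carries one value per action component, shape (4,). If the four components are treated as independent categorical distributions (an assumption made here for illustration, not something confirmed about coax internals), the joint log-probability is the sum over components, which yields one value per transition. A plain-Python sketch with hypothetical helper names:

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def joint_log_prob(component_logits, action):
    """Sum the per-component log-probabilities of a MultiDiscrete action.

    component_logits: one list of logits per action component.
    action: one chosen index per component.
    """
    return sum(
        log_softmax(logits)[a] for logits, a in zip(component_logits, action)
    )


if __name__ == "__main__":
    # four components with 10 choices each, all logits zero (uniform)
    uniform = [[0.0] * 10 for _ in range(4)]
    print(joint_log_prob(uniform, [3, 1, 4, 1]))  # equals 4 * log(1/10)
```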
    
    question 
    opened by xiangyuy 5
  • 'linear/w' does not match shape

    I've been starting to learn about RL and have been trying to get coax up and running, but have run into an issue that I'm not sure how to resolve. I'm doing Q-learning on a custom gym environment, and I can run the following pieces successfully:

    q = coax.Q(func_q, env)
    pi = coax.Policy(func_pi, env)
    
    qlearning = coax.td_learning.QLearning(q, pi_targ=pi, optimizer=optax.adam(0.001))
    cache = coax.reward_tracing.NStep(n=1, gamma=0.9)
    

    Additionally, my setup passes the simple checks of:

    data = coax.Q.example_data(env) # Looks good
    ...
    s = env.observation_space.sample()
    a = env.action_space.sample()
    print(q(s,a)) # 0.0
    ...
    a = pi(s)
    print(a) # [0, 0, 0, 0, 0] as I have a MultiDiscrete action space
    

    However, once I get to actually running the training loop:

    for ep in range(50):
      pi.epsilon = 0.1
      s = env.reset()
    
      for t in range(env.maxGuesses):
        a = pi(s)
        s_next, r, done, info = env.step(a)
    
        # update
        cache.add(s, a, r, done)
    
        while cache:
          transition_batch = cache.pop()
          metrics = qlearning.update(transition_batch)
          env.record_metrics(metrics)
    
        if done:
          break
    
        s = s_next
    
        # early stopping
        if env.avg_G > env.reward_threshold:
          break
    

    I get a bunch of errors with the most human-readable of them saying:

    ValueError: 'linear/w' with retrieved shape (420, 30) does not match shape=[940, 30] dtype=dtype('float32')
    

    By adjusting the parameters of the environment, I can change which numbers are mismatched, but I can't get them to match. Either way, that seems like the wrong solution, as the underlying issue appears to be something more fundamental.

    For reference, here are my functions for q and pi:

    def func_pi(S, is_training):
      logits = hk.Sequential((
        hk.Linear(30), jax.nn.relu, 
        hk.Linear(30), jax.nn.relu, 
        hk.Linear(30), jax.nn.relu,
        hk.Linear(Wordle.wordLength*len(alphabet), w_init=jnp.zeros) # This many possible actions
      ))
      # First, convert to a vector:
      sVec = state_to_vec(S)
    
      # Now get the output:
      logitVec = logits(sVec)
    
      # Now chunk the output into alphabet-sized pieces (definitionally an integral
      # number of them). There will be Wordle.wordLength chunks of this length
      chunks = jnp.split(logitVec, Wordle.wordLength)
    
      # Now format our output array:
      ret = []
      for chunk in chunks:
        ret.append({'logits': jnp.reshape(chunk,(1,len(alphabet)))})
    
      return tuple(ret)
    
    # and for actual state:
    def func_q(S, A, is_training):
      value = hk.Sequential((
        hk.Linear(30), jax.nn.relu, 
        hk.Linear(30), jax.nn.relu,
        hk.Linear(30), jax.nn.relu,
        hk.Linear(1, w_init=jnp.zeros), jnp.ravel
      ))
    
      sVec = state_to_vec(S)
      aVec = action_to_vec(A)
    
      X = jnp.concatenate((sVec, aVec))
      return value(X)
    

    Note that state_to_vec(S) and action_to_vec(A) just convert from my internal types to jnp.array's for use with Haiku.

    I'm quite new to coax/JAX/Haiku so it's entirely possible I've set something up wrong. For completeness here's the full text of the error:

    Traceback (most recent call last):
      File "/home/user/wordle/wordle_ai/wordle-game-rl.py", line 314, in <module>
        metrics = qlearning.update(transition_batch)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 87, in update
        grads, function_state, metrics, td_error = self.grads_and_metrics(transition_batch)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 149, in grads_and_metrics
        return self._grads_and_metrics_func(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/utils/_jit.py", line 59, in __call__
        return self._jitted_func(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/api.py", line 426, in cache_miss
        out_flat = xla.xla_call(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1671, in bind
        return call_bind(self, fun, *args, **params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1683, in call_bind
        outs = top_trace.process_call(primitive, fun, tracers, params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 596, in process_call
        return primitive.impl(f, *tracers, **params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/dispatch.py", line 142, in _xla_call_impl
        compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/linear_util.py", line 272, in memoized_fun
        ans = call(fun, *args)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/dispatch.py", line 169, in _xla_callable_uncached
        return lower_xla_callable(fun, device, backend, name, donated_invars,
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/profiler.py", line 206, in wrapper
        return func(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/dispatch.py", line 197, in lower_xla_callable
        jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/profiler.py", line 206, in wrapper
        return func(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1623, in trace_to_jaxpr_final
        jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1594, in trace_to_subjaxpr_dynamic
        ans = fun.call_wrapped(*in_tracers_)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/linear_util.py", line 166, in call_wrapped
        ans = self.f(*args, **dict(self.params, **kwargs))
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 462, in grads_and_metrics_func
        grads, (td_error, state_new, metrics) = jax.grad(loss_func, has_aux=True)(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/api.py", line 996, in grad_f_aux
        (_, aux), g = value_and_grad_f(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/api.py", line 1067, in value_and_grad_f
        ans, vjp_py, aux = _vjp(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/api.py", line 2478, in _vjp
        out_primal, out_vjp, aux = ad.vjp(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/ad.py", line 118, in vjp
        out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/ad.py", line 103, in linearize
        jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/profiler.py", line 206, in wrapper
        return func(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 520, in trace_to_jaxpr
        jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/linear_util.py", line 166, in call_wrapped
        ans = self.f(*args, **dict(self.params, **kwargs))
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 436, in loss_func
        Q, state_new = self.q.function_type1(params, state, next(rngs), S, A, True)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/utils/_jit.py", line 59, in __call__
        return self._jitted_func(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/api.py", line 426, in cache_miss
        out_flat = xla.xla_call(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1671, in bind
        return call_bind(self, fun, *args, **params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1683, in call_bind
        outs = top_trace.process_call(primitive, fun, tracers, params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/ad.py", line 324, in process_call
        result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1671, in bind
        return call_bind(self, fun, *args, **params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1683, in call_bind
        outs = top_trace.process_call(primitive, fun, tracers, params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 204, in process_call
        jaxpr, out_pvals, consts, env_tracers = self.partial_eval(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 317, in partial_eval
        out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1671, in bind
        return call_bind(self, fun, *args, **params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1683, in call_bind
        outs = top_trace.process_call(primitive, fun, tracers, params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1364, in process_call
        jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1594, in trace_to_subjaxpr_dynamic
        ans = fun.call_wrapped(*in_tracers_)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/linear_util.py", line 166, in call_wrapped
        ans = self.f(*args, **dict(self.params, **kwargs))
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/transform.py", line 383, in apply_fn
        out = f(*args, **kwargs)
      File "/home/user/wordle/wordle_ai/wordle-game-rl.py", line 264, in func_1
        return value(X)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 428, in wrapped
        out = f(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 279, in run_interceptors
        return bound_method(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/basic.py", line 125, in __call__
        out = layer(out, *args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 428, in wrapped
        out = f(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 279, in run_interceptors
        return bound_method(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/basic.py", line 178, in __call__
        w = hk.get_parameter("w", [input_size, output_size], dtype, init=w_init)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/base.py", line 319, in get_parameter
        raise ValueError(
    jax._src.traceback_util.UnfilteredStackTrace: ValueError: 'linear/w' with retrieved shape (420, 30) does not match shape=[940, 30] dtype=dtype('float32')
    
    The stack trace below excludes JAX-internal frames.
    The preceding is the original exception that occurred, unmodified.
    
    --------------------
    
    The above exception was the direct cause of the following exception:
    
    Traceback (most recent call last):
      File "/home/user/wordle/wordle_ai/wordle-game-rl.py", line 314, in <module>
        metrics = qlearning.update(transition_batch)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 87, in update
        grads, function_state, metrics, td_error = self.grads_and_metrics(transition_batch)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 149, in grads_and_metrics
        return self._grads_and_metrics_func(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/utils/_jit.py", line 59, in __call__
        return self._jitted_func(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 462, in grads_and_metrics_func
        grads, (td_error, state_new, metrics) = jax.grad(loss_func, has_aux=True)(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 436, in loss_func
        Q, state_new = self.q.function_type1(params, state, next(rngs), S, A, True)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/utils/_jit.py", line 59, in __call__
        return self._jitted_func(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/transform.py", line 383, in apply_fn
        out = f(*args, **kwargs)
      File "/home/user/wordle/wordle_ai/wordle-game-rl.py", line 264, in func_1
        return value(X)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 428, in wrapped
        out = f(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 279, in run_interceptors
        return bound_method(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/basic.py", line 125, in __call__
        out = layer(out, *args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 428, in wrapped
        out = f(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 279, in run_interceptors
        return bound_method(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/basic.py", line 178, in __call__
        w = hk.get_parameter("w", [input_size, output_size], dtype, init=w_init)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/base.py", line 319, in get_parameter
        raise ValueError(
    ValueError: 'linear/w' with retrieved shape (420, 30) does not match shape=[940, 30] dtype=dtype('float32')
    

    Please let me know if other information would be useful or relevant (or let me know if this isn't actually a coax issue...).

    Thanks for your help and the neat package.

    bug good first issue question 
    opened by bcerjan 5
  • DQN pong example doesn't work off the shelf

    Describe the bug

    Running the DQN example on pong generates the following error when generating a gif:

      File ".../lib/python3.9/site-packages/coax/utils/_misc.py", line 475, in generate_gif
        assert env.render_mode == 'rgb_array', "env.render_mode must be 'rgb_array'"
    

    This is likely due to some recent updates to gym. Currently, on gym==0.26.2 I observe the following:

    import gym
    env = gym.make('PongNoFrameskip-v4', render_mode="rgb_array")
    print(env.render_mode) # prints None
    
    opened by thisiscam 4
  • Add dm_control example for SAC

    This PR introduces the common squashed normal distribution for the SAC policy on dm_control and provides an example that solves the walker.walk task. Interestingly, training diverges when the actions are clipped to the range [-1, 1].
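For context, a tanh-squashed normal distribution samples u from a Gaussian and returns a = tanh(u), which keeps actions inside (-1, 1) without hard clipping. Its log-density needs a change-of-variables correction. The following is a minimal one-dimensional sketch of that standard formula, not the implementation in this PR; the function name is hypothetical.

```python
import math


def squashed_normal_log_prob(u, mu, sigma):
    """Log-density of a = tanh(u), where u ~ Normal(mu, sigma).

    log p(a) = log N(u; mu, sigma) - log(1 - tanh(u)^2)
    """
    log_normal = (
        -0.5 * ((u - mu) / sigma) ** 2
        - math.log(sigma)
        - 0.5 * math.log(2.0 * math.pi)
    )
    # log(1 - tanh(u)^2) in a numerically stable form:
    # 2 * (log 2 - u - softplus(-2u))
    softplus = math.log1p(math.exp(-2.0 * abs(u))) + max(-2.0 * u, 0.0)
    log_det = 2.0 * (math.log(2.0) - u - softplus)
    return log_normal - log_det
```

The function takes the pre-squash sample u rather than the action a, which avoids having to invert tanh near the boundaries.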

    @KristianHolsheimer How would you go about changing the installation script for this notebook to add dm_control as a dependency?

    opened by frederikschubert 4
  • Frozen Lake example has an invalid gym signature.

    Describe the bug

    The example for Frozen Lake in the main branch of the docs isn't fully updated for the new version of gym's signature.

    ValueError                                Traceback (most recent call last)
    in
         77
         78 a = pi.mode(s)
    ---> 79 s, r, done, info = env.step(a)
         80
         81 env.render()

    ValueError: too many values to unpack (expected 4)

    Expected behavior

    Executing the notebook should not result in a ValueError.

    To Reproduce

    Colab notebook to repro the bug:

    - https://colab.research.google.com/...

    Runtime used for this colab notebook: ... (e.g. CPU/GPU/TPU)

    Any.

    Additional context

    Simple fix, happy to contribute a pull request.
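As background for the fix: since gym 0.26, env.step returns five values (obs, reward, terminated, truncated, info) instead of the older four (obs, reward, done, info). A small compatibility helper can accept both. This is a sketch, and `unpack_step` is a hypothetical name rather than part of gym or coax.

```python
def unpack_step(step_result):
    """Normalize the result of env.step to (obs, reward, done, info).

    Accepts the older 4-tuple (obs, reward, done, info) and the newer
    5-tuple (obs, reward, terminated, truncated, info).
    """
    if len(step_result) == 5:
        obs, reward, terminated, truncated, info = step_result
        return obs, reward, terminated or truncated, info
    obs, reward, done, info = step_result
    return obs, reward, done, info
```

With such a helper, the failing line in the notebook would read `s, r, done, info = unpack_step(env.step(a))`.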

    opened by dbleyl 3
  • Incorporating jax.jit into a custom policy

    I'm a bit new to JAX, so my question might sound very naive. Suppose we are trying to solve a policy optimization problem with the REINFORCE algorithm, and suppose we already have our environment at hand (env). We define our custom policy as follows:

    class CustomPolicy(hk.Module):
        def __init__(self, name = None):
            super().__init__(name = name)
        
    
        def __call__(self, x):
            w = hk.get_parameter("w", shape= ... , dtype = x.dtype, init=jnp.zeros)
            # some computation
            return out
    

    Per the documentation, then we define

    def custom_policy(S, is_training=True):
        logits = CustomPolicy()
        return {'logits': logits(S)}
    

    and finally the policy is stated as follows,

    pi = coax.Policy(custom_policy, env)

    I was wondering whether there is any way to incorporate @jax.jit into this structure to further improve performance. Thanks.

    question 
    opened by UweGensheimer 3
  • Multi-Step Entropy Regularization for SAC

    • Add record_extra_info flag to the NStep tracer that records the intermediate states in the new extra_info field to TransitionBatch
    • Add support for the NStepEntropyRegularizer in SoftPG

    This PR contains an initial working implementation of the mechanism and sums up the discounted entropy bonuses of the states s_t, s_{t + 1}, ... , s_{t + n - 1} for the soft policy gradient regularization.
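Written out, the summed bonus for an n-step transition starting at s_t is the sum over k = 0, ..., n - 1 of gamma^k times the policy entropy at s_{t + k}. A plain-Python sketch of just that sum, assuming the per-state entropies have already been computed; the function name is hypothetical and any regularization coefficient is deliberately left out:

```python
def discounted_entropy_bonus(entropies, gamma):
    """Sum gamma**k * H_k over the entropies of s_t, ..., s_{t+n-1}."""
    return sum(gamma ** k * h for k, h in enumerate(entropies))


if __name__ == "__main__":
    # three intermediate states with unit entropy and gamma = 0.5
    print(discounted_entropy_bonus([1.0, 1.0, 1.0], gamma=0.5))
```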

    opened by frederikschubert 3
  • Implementation of SAC

    Since SAC is really similar to TD3, we are able to re-use most of its components. The differences are:

    • The actions to update the q-functions and policy are sampled using the current policy (instead of taking the mode).
    • There is no target policy.
    • The log variance of the policy depends on the state.
    • The policy is entropy regularized.

    The current implementation does not support multi-step td-learning.

    opened by frederikschubert 3
  • AttributeError: module 'jax.api' has no attribute '_jit_is_disabled'

    Hi, I'm unsure whether this is due to coax or jax, but I get this error when running the pendulum PPO example; DQN runs fine, however.

    A similar error I found online recommended changing the version of jaxlib, so I switched to the jaxlib version given in the coax getting started guide, but that seemed to have no effect. My versions:

    • jax version = 0.2.13
    • jaxlib version = 0.1.65 + cuda111
    • coax version = 0.1.6

    question 
    opened by mmcaulif 3
  • Recurrent Experience Replay

    Is your feature request related to a problem? Please describe.

    It seems that the implemented replay buffers only operate over transitions, with no ability to operate over entire sequences. This prevents the use of recurrent policies for tackling POMDPs.

    Describe the solution you'd like

    A SequenceReplayBuffer that returns contiguous episodes instead of shuffled transitions.
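A rough sketch of what such a buffer could look like in plain Python. Everything here is hypothetical: the method names `add` and `sample`, the episode-level capacity, and the choice to sample fixed-length contiguous windows are assumptions for illustration, not an existing coax API.

```python
import random
from collections import deque


class SequenceReplayBuffer:
    """Stores whole episodes and samples contiguous windows from them."""

    def __init__(self, capacity, seed=None):
        self._episodes = deque(maxlen=capacity)  # oldest episodes are evicted
        self._current = []
        self._rng = random.Random(seed)

    def add(self, transition, done):
        """Append a transition; close the episode when done is True."""
        self._current.append(transition)
        if done:
            self._episodes.append(self._current)
            self._current = []

    def sample(self, seq_len):
        """Return seq_len consecutive transitions from one stored episode."""
        candidates = [ep for ep in self._episodes if len(ep) >= seq_len]
        if not candidates:
            raise ValueError("no stored episode is long enough")
        episode = self._rng.choice(candidates)
        start = self._rng.randrange(len(episode) - seq_len + 1)
        return episode[start:start + seq_len]
```

Because each sample is a slice of a single episode, the temporal order that a recurrent policy relies on is preserved.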

    Describe alternatives you've considered

    Additional context

    enhancement 
    opened by smorad 3
  • MiniMax Algorithm?

    How would you implement a minimax q-learner with coax?

    Hi there! I love the package and how accessible it is to relative newbies. The tutorials are pretty great and the accompanying videos are very helpful!

    I was wondering what the best way to implement a minimax algorithm would be. Would you recommend using two policies, pi1 and pi2? Or is there something better suited for this?

    I'd like to re-implement something like this old blogpost of mine in coax to get a better feel of the library.

    Any help would be greatly appreciated :)

    question 
    opened by flaport 1
  • Convert Numpy Docstrings to Google Style

    This issue tracks the progress of converting the numpy style docstrings to the more concise Google style.

    • [ ] _core
    • [ ] experience_replay
    • [ ] model_updaters
    • [ ] policy_objectives
    • [ ] proba_dists
    • [ ] reward_tracing
    • [ ] td_learning
    • [ ] utils
    • [ ] value_transforms
    • [ ] wrappers

    This depends on the type annotations https://github.com/coax-dev/coax/issues/13 for easier automatic conversions.

    enhancement 
    opened by frederikschubert 0
  • Add Type Annotations

    This issue tracks the progress of adding type annotations to coax.

    • [ ] _core
    • [ ] experience_replay
    • [ ] model_updaters
    • [ ] policy_objectives
    • [ ] proba_dists
    • [ ] reward_tracing
    • [ ] td_learning
    • [ ] utils
    • [ ] value_transforms
    • [ ] wrappers

    The types are added by utilising pyannotate and adding the following snippet to the coax._base.TestCase class:

    import sys

    from pyannotate_runtime import collect_types

    ...
        @classmethod
        def setUpClass(cls) -> None:
            # start collecting runtime type information before any test runs
            collect_types.init_types_collection()
            collect_types.start()

        @classmethod
        def tearDownClass(cls) -> None:
            collect_types.stop()
            # map inferred internal types to the names used in the annotations
            type_replacements = {
                "jaxlib.xla_extension.DeviceArray": "jax.numpy.ndarray",
                "haiku._src.data_structures.FlatMapping": "typing.Mapping",
                "coax._core.policy_test": "gym.Env"
            }
            types_str = collect_types.dumps_stats()
            for inferred_type, replacement in type_replacements.items():
                types_str = types_str.replace(inferred_type, replacement)
            # write the stats next to the test module, e.g. foo_test.py -> foo_test_types.json
            types_path = sys.modules[cls.__module__].__file__.replace(".py", "_types.json")
            with open(types_path, "w") as f:
                f.write(types_str)
    ...
    

    The collected types are then applied automatically with:

    # note: the recursive ** glob needs zsh, or bash with `shopt -s globstar`
    for t in coax/**/*_test_types.json
    do
        pyannotate --type-info "$t" -3 coax/* -w
    done
    
    enhancement 
    opened by frederikschubert 0
  • PPOClip grad update seems to cause inf update

    PPOClip grad update seems to cause inf update

    Describe the bug

    Hey Kris, love your framework! I'm working with a custom environment, and your discrete-action unit test works perfectly locally. Don't spend much time investigating this yet; I'm just creating this issue in case something jumps out at you as the problem. I plan on continuing to debug this issue.

    During the first PPOClip update with the custom gym, the model weights get changed to +/-inf despite a non-infinite grad.

    Expected behavior

    ...
    adv = np.random.rand(32)
    grads, function_state, metrics = ppo_clip.grads_and_metrics(transition_batch, Adv=adv)
    print("grads", grads)
    print(ppo_clip._pi.params)
    metrics_pi = ppo_clip.update(transition_batch, Adv=adv) # This is the problem
    print(ppo_clip._pi.params)
    

    Results in:

    grads FlatMapping({
      'linear': FlatMapping({
                  'b': DeviceArray([ 0.0477 , -0.02505, -0.05048,  0.02798], dtype=float16),
                  'w': DeviceArray([[ 0.01338 , -0.01921 , -0.01038 ,  0.01622 ],
                                    [ 0.02406 , -0.01683 , -0.02039 ,  0.01316 ],
                                    [ 0.0332  , -0.0227  , -0.03108 ,  0.02061 ],
                                    ...,
                                    [ 0.02452 , -0.00956 , -0.01997 ,  0.005024],
                                    [ 0.010025,  0.001724, -0.03467 ,  0.02295 ],
                                    [ 0.01886 , -0.01413 , -0.01494 ,  0.01022 ]], dtype=float16),
                }),
    
    FlatMapping({
      'linear': FlatMapping({
                  'w': DeviceArray([[-1.0124e-02,  3.4389e-03,  2.9316e-03,  6.5498e-03],
                                    [ 3.3302e-03, -1.7233e-03, -3.0422e-03, -1.8060e-04],
                                    [-2.8908e-05, -3.3131e-03, -6.1073e-03,  6.5804e-03],
                                    ...,
                                    [-2.5597e-03,  7.3471e-03, -3.6221e-03, -5.6801e-03],
                                    [-7.3471e-03, -3.7746e-03,  5.8746e-03,  6.1531e-03],
                                    [-1.1940e-03,  6.9733e-03, -5.0507e-03,  3.4218e-03]],            dtype=float16),
                  'b': DeviceArray([0., 0., 0., 0.], dtype=float16),
                }),
    })
    
    FlatMapping({
      'linear': FlatMapping({
                  'b': DeviceArray([-0.001002,  0.000978,  0.001001, -0.001007], dtype=float16),
                  'w': DeviceArray([[-0.01111  ,  0.004448 ,  0.00386  ,  0.00551  ],
                                    [ 0.002354 , -0.0007563, -0.002048 , -0.001162 ],
                                    [-0.001021 , -0.002335 , -0.005104 ,  0.005558 ],
                                    ...,
                                    [-0.003561 ,  0.008224 , -0.002628 ,       -inf],
                                    [-0.00828  ,       -inf,  0.006874 ,  0.00515  ],
                                    [-0.002203 ,  0.00804  , -0.004086 ,  0.002493 ]],            dtype=float16),
                }),
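
When the parameter dump is large, the non-finite entries can be hard to spot by eye. The helper below is a small sketch for locating them; `find_nonfinite` is a hypothetical name and not part of coax, and it assumes the parameters have first been converted to plain nested dicts and lists of Python floats (for example with NumPy's `tolist()`).

```python
import math


def find_nonfinite(tree, path=()):
    """Yield (path, value) for every NaN or infinite leaf in a nested dict/list."""
    if isinstance(tree, dict):
        for key, sub in tree.items():
            yield from find_nonfinite(sub, path + (key,))
    elif isinstance(tree, (list, tuple)):
        for i, sub in enumerate(tree):
            yield from find_nonfinite(sub, path + (i,))
    elif not math.isfinite(tree):
        yield path, tree
```

Calling `list(find_nonfinite(params))` before and after the update shows which weights changed from finite to non-finite values.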
    

    Here is the full repro script, taken from the Pong PPO example and slightly modified. It won't run as-is because it depends on the custom environment. This is a dummy example, not the actual policy and value networks that would be used:

    import os
    from luxai2021.env.lux_env import LuxEnvironment, LuxEnvironmentTeam
    from luxai2021.game.game import Game
    from luxai2021.game.actions import *
    from luxai2021.game.constants import LuxMatchConfigs_Default
    
    from luxai2021.env.agent import Agent, AgentWithTeamModel
    import numpy as np
    
    from agent import TeamAgent
    
    # set some env vars
    os.environ.setdefault('JAX_PLATFORM_NAME', 'cpu')     # tell JAX to use CPU
    os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.1'  # don't use all gpu mem
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'              # tell XLA to be quiet
    
    import gym
    import jax
    import coax
    import haiku as hk
    import jax.numpy as jnp
    from optax import adam
    
    
    # the name of this script
    name = 'ppo'
    
    configs = LuxMatchConfigs_Default
    
    player = TeamAgent(mode="train")
    opponent = Agent()
    
    env = LuxEnvironment(configs=configs,
                                    learning_agent=player,
                                    opponent_agent=opponent)
    env = coax.wrappers.TrainMonitor(env, name=name, tensorboard_dir=f"./data/tensorboard/{name}")
    
    def func_pi(S, is_training):
        n_actions = 4
        out = {'logits': hk.Linear(n_actions)(hk.Flatten()(S)) }
        return out
    
    def func_v(S, is_training):
        h = jnp.ravel(hk.Linear(1)(hk.Flatten()(S)))
        return h
    
    '''
    def func_pi(S, is_training):
        #print(env.action_space.shape)
        n_filters = 5
        n_actions = 4
        n_layers = 3
    
        h = hk.Conv2D(n_filters, kernel_shape=3, stride=1, padding="SAME", data_format='NCHW')(S)
        for layer in range(n_layers):
            h = jax.nn.relu(h + hk.Conv2D(n_filters, kernel_shape=3, stride=1, padding="SAME", data_format='NCHW')(h))
        
        print('h', type(h), h.shape)
        h_head = (h * S[:,:1]).reshape(h.shape[0], h.shape[1], -1).sum(-1) # torch.Size([1, N_LAYERS])
        h_head_actions = hk.Linear(n_actions)(h_head)
        print('h_head_actions', type(h_head_actions), h_head_actions.shape)
        #print(h_head_actions)
    
        out = {'logits': h_head_actions}
        
        return out
    
    def func_v(S, is_training):
        n_filters = 5
        n_layers = 3
    
        h = hk.Conv2D(n_filters, kernel_shape=3, stride=1, padding="SAME", data_format='NCHW')(S)
        for layer in range(n_layers):
            h = jax.nn.relu(hk.Conv2D(n_filters, kernel_shape=3, stride=2, data_format='NCHW')(h))
    
        h = hk.Flatten()(h)
        h = jax.nn.relu(hk.Linear(64)(h))
        h = jnp.ravel(hk.Linear(1, w_init=jnp.zeros)(h))
        
        return h
    '''
    
    
    # function approximators
    pi = coax.Policy(func_pi, env)
    v = coax.V(func_v, env)
    
    # target networks
    pi_behavior = pi.copy()
    v_targ = v.copy()
    
    # policy regularizer (avoid premature exploitation)
    entropy = coax.regularizers.EntropyRegularizer(pi, beta=0.001)
    
    # updaters
    simpletd = coax.td_learning.SimpleTD(v, v_targ, optimizer=adam(3e-4))
    ppo_clip = coax.policy_objectives.PPOClip(pi, regularizer=entropy, optimizer=adam(3e-4))
    
    # reward tracer and replay buffer
    tracer = coax.reward_tracing.NStep(n=5, gamma=0.99)
    buffer = coax.experience_replay.SimpleReplayBuffer(capacity=256)
    
    # run episodes
    max_episode_steps = 400
    while env.T < 3000000:
        s = env.reset()
    
        for t in range(max_episode_steps):
            print(t)
            a, logp = pi_behavior(s, return_logp=True)
            s_next, r, done, info = env.step(a)
    
            # trace rewards and add transition to replay buffer
            tracer.add(s, a, r, done, logp)
            while tracer:
                buffer.add(tracer.pop())
    
            # learn
            if len(buffer) >= buffer.capacity:
                num_batches = int(4 * buffer.capacity / 32)  # 4 epochs per round
                for i in range(num_batches):
                    transition_batch = buffer.sample(32)
                    grads, function_state, metrics, td_error = simpletd.grads_and_metrics(transition_batch)
                    metrics_v, td_error = simpletd.update(transition_batch, return_td_error=True)
    
                    
                    adv = np.random.rand(32)
                    grads, function_state, metrics = ppo_clip.grads_and_metrics(transition_batch, Adv=adv)
                    print("grads", grads)
                    print(ppo_clip._pi.params)
                    metrics_pi = ppo_clip.update(transition_batch, Adv=adv) # This is the problem
                    print(ppo_clip._pi.params)
                    exit()
                    env.record_metrics(metrics_pi)
                    env.record_metrics(metrics_v)
                    
    
                buffer.clear()
    
                # sync target networks
                pi_behavior.soft_update(pi, tau=0.1)
                v_targ.soft_update(v, tau=0.1)
    
            if done:
                break
    
            s = s_next
    
        # generate an animated GIF to see what's going on
        if env.period(name='generate_gif', T_period=10000) and env.T > 50000:
            T = env.T - env.T % 10000  # round to 10000s
            coax.utils.generate_gif(
                env=env, policy=pi, resize_to=(320, 420),
                filepath=f"./data/gifs/{name}/T{T:08d}.gif")
    
    
    opened by glmcdona 3
Releases(v0.1.12)