Deep reinforcement learning library built on top of Neural Network Libraries

Overview


NNablaRL is a deep reinforcement learning library built on top of Neural Network Libraries that is intended to be used for research, development and production.

Installation

Installing NNablaRL is easy!

$ pip install nnabla-rl

NNablaRL only supports Python version >= 3.6 and NNabla version >= 1.17.

Enabling GPU acceleration (Optional)

NNablaRL algorithms run on the CPU by default. To run an algorithm on a GPU, first install nnabla-ext-cuda as follows. (Replace [cuda-version] according to the CUDA version installed on your machine.)

$ pip install nnabla-ext-cuda[cuda-version]
# Example installation. Supposing CUDA 11.0 is installed on your machine.
$ pip install nnabla-ext-cuda110

After installing nnabla-ext-cuda, set the GPU ID to run the algorithm on through the algorithm's configuration.

import nnabla_rl.algorithms as A

config = A.DQNConfig(gpu_id=0) # Use gpu 0. If negative, will run on CPU.
dqn = A.DQN(env, config=config)
...

Features

Friendly API

NNablaRL has friendly Python APIs that enable you to start training with only 3 lines of Python code.

import nnabla_rl
import nnabla_rl.algorithms as A
from nnabla_rl.utils.reproductions import build_atari_env

env = build_atari_env("BreakoutNoFrameskip-v4") # 1
dqn = A.DQN(env)  # 2
dqn.train(env)  # 3

To get more details about NNablaRL, see the documentation and examples.

Many built-in algorithms

Most famous and SOTA deep reinforcement learning algorithms, such as DQN, SAC, BCQ, GAIL, etc., are implemented in NNablaRL. The implemented algorithms are carefully tested and evaluated, so you can easily start training your agent using these verified implementations.

For the list of implemented algorithms, see here.

You can also find the reproduction and evaluation results of each algorithm here.
Note that you may not get exactly the same results when running the reproduction code on your computer. Results may vary slightly depending on your machine, the nnabla/nnabla-rl package versions, etc.

Seamless switching of online and offline training

In reinforcement learning, there are two main procedures for training the agent: online and offline. Online training alternates between data collection and network updates. Conversely, offline training updates the network using only pre-collected data. With NNablaRL, you can switch between these two training procedures seamlessly. For example, as shown below, you can easily train a robot's controller online using a simulated environment and fine-tune it offline with a real robot dataset.

import nnabla_rl
import nnabla_rl.algorithms as A

simulator = get_simulator() # This is just an example. Assuming that a simulator exists
dqn = A.DQN(simulator)
# train online for 1M iterations
dqn.train_online(simulator, total_iterations=1000000)

real_data = get_real_robot_data() # This is also an example. Assuming that you have real robot data
# fine tune the agent offline for 10k iterations using real data
dqn.train_offline(real_data, total_iterations=10000)

Getting started

Try the interactive demos below to get started.
You can run them directly on Colab from the links in the table below.

Title | Notebook | Target RL task
Simple reinforcement learning training to get started | Open In Colab | Pendulum
Learn how to use training algorithms | Open In Colab | Pendulum
Learn how to use a customized network model for training | Open In Colab | Mountain car
Learn how to use a different network solver for training | Open In Colab | Pendulum
Learn how to use a different replay buffer for training | Open In Colab | Pendulum
Learn how to use your own environment for training | Open In Colab | Customized environment
Atari game training example | Open In Colab | Atari games

Documentation

Full documentation is here.

Contribution guide

Any kind of contribution to NNablaRL is welcome! See the contribution guide for details.

License

NNablaRL is provided under the Apache License, Version 2.0.

Comments
  • Update cem function interface

    The interface of the cross entropy method (CEM) functions has been updated. The argument pop_size has been renamed to sample_size. In addition, the objective function given to the CEM function will now be called with a variable x of shape (batch_size, sample_size, x_dim). This differs from the previous interface. For details, please see the function docs.
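
    For illustration, here is a minimal sketch of an objective function matching the new calling convention (the function body and the use of numpy are assumptions; only the shape of x comes from this change):

    import numpy as np

    def objective(x):
        # Under the new interface, x has shape (batch_size, sample_size, x_dim).
        # Return one scalar per candidate, i.e. an array of shape (batch_size, sample_size).
        return np.sum(x ** 2, axis=-1)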

    opened by sbsekiguchi 1
  • Add implementation for RNN support and DRQN algorithm

    Add RNN model support and DRQN algorithm.

    The following trainers will support RNN models:

    • Q value-based trainers
    • Deterministic gradient and Soft policy trainers

    Other trainers may support RNN models in the future, but this is not implemented in the initial release.

    See this paper for the details of the DRQN algorithm.
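
    As a rough usage sketch, assuming DRQN follows the same algorithm API as the DQN examples above (the class and method names here are assumptions, not confirmed by this PR):

    import nnabla_rl.algorithms as A
    from nnabla_rl.utils.reproductions import build_atari_env

    # Recurrent models suit partially observable tasks such as Atari games
    env = build_atari_env("BreakoutNoFrameskip-v4")
    drqn = A.DRQN(env)  # assumed entry point for the algorithm added in this PR
    drqn.train_online(env, total_iterations=1000000)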

    opened by ishihara-y 1
  • Implement SACD

    This PR implements the SAC-D algorithm (https://arxiv.org/abs/2206.13901).

    These changes have been made:

    • New environments with factored reward functions have been added
      • FactoredLunarLanderContinuousV2NNablaRL-v1
      • FactoredAntV4NNablaRL-v1
      • FactoredHopperV4NNablaRL-v1
      • FactoredHalfCheetahV4NNablaRL-v1
      • FactoredWalker2dV4NNablaRL-v1
      • FactoredHumanoidV4NNablaRL-v1
    • SACD algorithm has been added
    • SoftQDTrainer has been added
    • _InfluenceMetricsEvaluator has been added
    • A reproduction script has been added (not benchmarked yet)

    visualizing influence metrics

    import gym
    
    import numpy as np
    import matplotlib.pyplot as plt
    
    import nnabla_rl.algorithms as A
    import nnabla_rl.hooks as H
    import nnabla_rl.writers as W
    from nnabla_rl.utils.evaluator import EpisodicEvaluator
    
    env = gym.make("FactoredLunarLanderContinuousV2NNablaRL-v1")
    eval_env = gym.make("FactoredLunarLanderContinuousV2NNablaRL-v1")
    
    evaluation_hook = H.EvaluationHook(
        eval_env,
        EpisodicEvaluator(run_per_evaluation=10),
        timing=5000,
        writer=W.FileWriter(outdir="logdir", file_prefix='evaluation_result'),
    )
    iteration_num_hook = H.IterationNumHook(timing=100)
    
    config = A.SACDConfig(gpu_id=0, reward_dimension=9)
    sacd = A.SACD(env, config=config)
    sacd.set_hooks([iteration_num_hook, evaluation_hook])
    sacd.train_online(env, total_iterations=100000)
    
    influence_history = []
    
    state = env.reset()
    while True:
        action = sacd.compute_eval_action(state)
        influence = sacd.compute_influence_metrics(state, action)
        influence_history.append(influence)
        state, _, done, _ = env.step(action)
        if done:
            break
    
    influence_history = np.array(influence_history)
    for i, label in enumerate(["position", "velocity", "angle", "left_leg", "right_leg", "main_engine", "side_engine", "failure", "success"]):
        plt.plot(influence_history[:, i], label=label)
    plt.xlabel("step")
    plt.ylabel("influence metrics")
    plt.legend()
    plt.show()
    


    opened by ishihara-y 0
  • Add gmm and Update gaussian

    Added GMM and Gaussian to the numpy models. In addition, the Gaussian distribution's API has been updated.

    The API change is as follows:

    Previous:

    import numpy as np

    import nnabla as nn
    import nnabla_rl.distributions as D

    batch_size = 10
    output_dim = 10
    input_shape = (batch_size, output_dim)
    mean = np.zeros(shape=input_shape)
    sigma = np.ones(shape=input_shape) * 5.
    ln_var = np.log(sigma) * 2.
    distribution = D.Gaussian(mean, ln_var)
    # return nn.Variable
    assert isinstance(distribution.sample(), nn.Variable)
    

    Updated:

    batch_size = 10
    output_dim = 10
    input_shape = (batch_size, output_dim)
    mean = np.zeros(shape=input_shape)
    sigma = np.ones(shape=input_shape) * 5.
    ln_var = np.log(sigma) * 2.
    # Pass nn.Variable if you want all of the class methods to return nn.Variable.
    distribution = D.Gaussian(nn.Variable.from_numpy_array(mean), nn.Variable.from_numpy_array(ln_var))
    assert isinstance(distribution.sample(), nn.Variable)
    
    # If you pass np.ndarray, then all class methods return np.ndarray
    # Currently, only inputs without a batch shape are supported (i.e. mean.shape = (dims,), ln_var.shape = (dims, dims)).
    distribution = D.Gaussian(mean[0], np.diag(ln_var[0]))  # without batch
    assert isinstance(distribution.sample(), np.ndarray)
    
    opened by sbsekiguchi 0
  • Support nnabla-browser

    • [x] add MonitorWriter
    • [x] save computational graph as nntxt

    example

    import gym
    
    import nnabla_rl.algorithms as A
    import nnabla_rl.hooks as H
    import nnabla_rl.writers as W
    from nnabla_rl.utils.evaluator import EpisodicEvaluator
    
    # save training computational graph
    training_graph_hook = H.TrainingGraphHook(outdir="test")
    
    # evaluation hook with nnabla's Monitor
    eval_env = gym.make("Pendulum-v0")
    evaluator = EpisodicEvaluator(run_per_evaluation=10)
    evaluation_hook = H.EvaluationHook(
        eval_env,
        evaluator,
        timing=10,
        writer=W.MonitorWriter(outdir="test", file_prefix='evaluation_result'),
    )
    
    env = gym.make("Pendulum-v0")
    sac = A.SAC(env)
    sac.set_hooks([training_graph_hook, evaluation_hook])
    
    sac.train_online(env, total_iterations=100)
    


    opened by ishihara-y 0
  • Add iLQR and LQR

    Implementation of the Linear Quadratic Regulator (LQR) and iterative LQR (iLQR) algorithms.
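
    For background, here is a minimal numpy sketch of the finite-horizon backward Riccati recursion underlying LQR; it illustrates the textbook algorithm, not nnabla-rl's actual API:

    import numpy as np

    def lqr_gains(A, B, Q, R, T):
        # Backward Riccati recursion for discrete-time dynamics x' = A x + B u
        # with stage cost x^T Q x + u^T R u over a horizon of T steps.
        # Returns gains K_t such that the optimal control is u_t = -K_t @ x_t.
        P = Q
        gains = []
        for _ in range(T):
            K = np.linalg.solve(R + B.T @ P @ B, B.T @ P @ A)
            P = Q + A.T @ P @ (A - B @ K)
            gains.append(K)
        return gains[::-1]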

    Co-authored-by: Yu Ishihara [email protected]
    Co-authored-by: Shunichi Sekiguchi [email protected]

    opened by ishihara-y 0
  • Check np_random instance and use correct randint alternative

    I am not sure when this change was made, but in some environments env.unwrapped.np_random returns a Generator instead of a RandomState.

    # in the case of RandomState
    # this line works
    env.unwrapped.np_random.randint(...)
    # in the case of Generator
    # randint does not exist and we must use integers as an alternative
    env.unwrapped.np_random.integers(...)
    

    This PR fixes the issue by checking the instance type and choosing the correct function for sampling integers.
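
    A minimal sketch of such a check (the helper name is hypothetical; the dispatch follows the PR description):

    import numpy as np

    def sample_integer(np_random, low, high):
        # Hypothetical helper: the new-style Generator exposes integers(),
        # while the legacy RandomState exposes randint().
        if isinstance(np_random, np.random.Generator):
            return np_random.integers(low, high)
        return np_random.randint(low, high)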

    opened by ishihara-y 0
  • Add icra2018 qtopt

    opened by sbsekiguchi 0
Releases(v0.12.0)
Owner: Sony Group Corporation