Deep Reinforcement Learning with pytorch & visdom

Overview

Deep Reinforcement Learning with

pytorch & visdom


  • Sample testings of trained agents (DQN on Breakout, A3C on Pong, DoubleDQN on CartPole, continuous A3C on InvertedPendulum(MuJoCo)):
  • Sample on-line plotting while training an A3C agent on Pong (with 16 learner processes): a3c_pong_plot

  • Sample loggings while training a DQN agent on CartPole (we use WARNING as the logging level currently to get rid of the INFO printouts from visdom):

[WARNING ] (MainProcess) <===================================>
[WARNING ] (MainProcess) bash$: python -m visdom.server
[WARNING ] (MainProcess) http://localhost:8097/env/daim_17040900
[WARNING ] (MainProcess) <===================================> DQN
[WARNING ] (MainProcess) <-----------------------------------> Env
[WARNING ] (MainProcess) Creating {gym | CartPole-v0} w/ Seed: 123
[INFO    ] (MainProcess) Making new env: CartPole-v0
[WARNING ] (MainProcess) Action Space: [0, 1]
[WARNING ] (MainProcess) State  Space: 4
[WARNING ] (MainProcess) <-----------------------------------> Model
[WARNING ] (MainProcess) MlpModel (
  (fc1): Linear (4 -> 16)
  (rl1): ReLU ()
  (fc2): Linear (16 -> 16)
  (rl2): ReLU ()
  (fc3): Linear (16 -> 16)
  (rl3): ReLU ()
  (fc4): Linear (16 -> 2)
)
[WARNING ] (MainProcess) No Pretrained Model. Will Train From Scratch.
[WARNING ] (MainProcess) <===================================> Training ...
[WARNING ] (MainProcess) Validation Data @ Step: 501
[WARNING ] (MainProcess) Start  Training @ Step: 501
[WARNING ] (MainProcess) Reporting       @ Step: 2500 | Elapsed Time: 5.32397913933
[WARNING ] (MainProcess) Training Stats:   epsilon:          0.972
[WARNING ] (MainProcess) Training Stats:   total_reward:     2500.0
[WARNING ] (MainProcess) Training Stats:   avg_reward:       21.7391304348
[WARNING ] (MainProcess) Training Stats:   nepisodes:        115
[WARNING ] (MainProcess) Training Stats:   nepisodes_solved: 114
[WARNING ] (MainProcess) Training Stats:   repisodes_solved: 0.991304347826
[WARNING ] (MainProcess) Evaluating      @ Step: 2500
[WARNING ] (MainProcess) Iteration: 2500; v_avg: 1.73136949539
[WARNING ] (MainProcess) Iteration: 2500; tderr_avg: 0.0964358523488
[WARNING ] (MainProcess) Iteration: 2500; steps_avg: 9.34579439252
[WARNING ] (MainProcess) Iteration: 2500; steps_std: 0.798395631184
[WARNING ] (MainProcess) Iteration: 2500; reward_avg: 9.34579439252
[WARNING ] (MainProcess) Iteration: 2500; reward_std: 0.798395631184
[WARNING ] (MainProcess) Iteration: 2500; nepisodes: 107
[WARNING ] (MainProcess) Iteration: 2500; nepisodes_solved: 106
[WARNING ] (MainProcess) Iteration: 2500; repisodes_solved: 0.990654205607
[WARNING ] (MainProcess) Saving Model    @ Step: 2500: /home/zhang/ws/17_ws/pytorch-rl/models/daim_17040900.pth ...
[WARNING ] (MainProcess) Saved  Model    @ Step: 2500: /home/zhang/ws/17_ws/pytorch-rl/models/daim_17040900.pth.
[WARNING ] (MainProcess) Resume Training @ Step: 2500
...

What is included?

This repo currently contains the following agents:

  • Deep Q Learning (DQN) [1], [2]
  • Double DQN [3]
  • Dueling network DQN (Dueling DQN) [4]
  • Asynchronous Advantage Actor-Critic (A3C) (w/ both discrete/continuous action space support) [5], [6]
  • Sample Efficient Actor-Critic with Experience Replay (ACER) (currently w/ discrete action space support (Truncated Importance Sampling, 1st Order TRPO)) [7], [8]

Work in progress:

  • Testing ACER

Future Plans:

  • Deep Deterministic Policy Gradient (DDPG) [9], [10]
  • Continuous DQN (CDQN or NAF) [11]

Code structure & Naming conventions:

NOTE: we follow the exact code structure as pytorch-dnc so as to make the code easily transplantable.

  • ./utils/factory.py

We suggest the users refer to ./utils/factory.py, where we list all the integrated Env, Model, Memory, Agent into Dict's. All of those four core classes are implemented in ./core/. The factory pattern in ./utils/factory.py makes the code super clean, as no matter what type of Agent you want to train, or which type of Env you want to train on, all you need to do is to simply modify some parameters in ./utils/options.py, then the ./main.py will do it all (NOTE: this ./main.py file never needs to be modified).

  • namings

To make the code more clean and readable, we name the variables using the following pattern (mainly in inherited Agent's):

  • *_vb: torch.autograd.Variable's or a list of such objects
  • *_ts: torch.Tensor's or a list of such objects
  • otherwise: normal python datatypes

Dependencies


How to run:

You only need to modify some parameters in ./utils/options.py to train a new configuration.

  • Configure your training in ./utils/options.py:
  • line 14: add an entry into CONFIGS to define your training (agent_type, env_type, game, model_type, memory_type)
  • line 33: choose the entry you just added
  • line 29-30: fill in your machine/cluster ID (MACHINE) and timestamp (TIMESTAMP) to define your training signature (MACHINE_TIMESTAMP), the corresponding model file and the log file of this training will be saved under this signature (./models/MACHINE_TIMESTAMP.pth & ./logs/MACHINE_TIMESTAMP.log respectively). Also the visdom visualization will be displayed under this signature (first activate the visdom server by type in bash: python -m visdom.server &, then open this address in your browser: http://localhost:8097/env/MACHINE_TIMESTAMP)
  • line 32: to train a model, set mode=1 (training visualization will be under http://localhost:8097/env/MACHINE_TIMESTAMP); to test the model of this current training, all you need to do is to set mode=2 (testing visualization will be under http://localhost:8097/env/MACHINE_TIMESTAMP_test).
  • Run:

python main.py


Bonus Scripts :)

We also provide 2 additional scripts for quickly evaluating your results after training. (Dependecies: lmj-plot)

  • plot.sh (e.g., plot from log file: logs/machine1_17080801.log)
  • ./plot.sh machine1 17080801
  • the generated figures will be saved into figs/machine1_17080801/
  • plot_compare.sh (e.g., compare log files: logs/machine1_17080801.log,logs/machine2_17080802.log)

./plot.sh 00 machine1 17080801 machine2 17080802

  • the generated figures will be saved into figs/compare_00/
  • the color coding will be in the order of: red green blue magenta yellow cyan

Repos we referred to during the development of this repo:


Citation

If you find this library useful and would like to cite it, the following would be appropriate:

@misc{pytorch-rl,
  author = {Zhang, Jingwei and Tai, Lei},
  title = {jingweiz/pytorch-rl},
  url = {https://github.com/jingweiz/pytorch-rl},
  year = {2017}
}
Owner
Jingwei Zhang
Jingwei Zhang
code for TCL: Vision-Language Pre-Training with Triple Contrastive Learning, CVPR 2022

Vision-Language Pre-Training with Triple Contrastive Learning, CVPR 2022 News (03/16/2022) upload retrieval checkpoints finetuned on COCO and Flickr T

187 Jan 02, 2023
Improving Factual Completeness and Consistency of Image-to-text Radiology Report Generation

Improving Factual Completeness and Consistency of Image-to-text Radiology Report Generation The reference code of Improving Factual Completeness and C

46 Dec 15, 2022
Our CIKM21 Paper "Incorporating Query Reformulating Behavior into Web Search Evaluation"

Reformulation-Aware-Metrics Introduction This codebase contains source-code of the Python-based implementation of our CIKM 2021 paper. Chen, Jia, et a

xuanyuan14 5 Mar 05, 2022
f-BRS: Rethinking Backpropagating Refinement for Interactive Segmentation

f-BRS: Rethinking Backpropagating Refinement for Interactive Segmentation [Paper] [PyTorch] [MXNet] [Video] This repository provides code for training

Visual Understanding Lab @ Samsung AI Center Moscow 516 Dec 21, 2022
InvTorch: memory-efficient models with invertible functions

InvTorch: Memory-Efficient Invertible Functions This module extends the functionality of torch.utils.checkpoint.checkpoint to work with invertible fun

Modar M. Alfadly 12 May 12, 2022
Py4fi2nd - Jupyter Notebooks and code for Python for Finance (2nd ed., O'Reilly) by Yves Hilpisch.

Python for Finance (2nd ed., O'Reilly) This repository provides all Python codes and Jupyter Notebooks of the book Python for Finance -- Mastering Dat

Yves Hilpisch 1k Jan 05, 2023
Parallel and High-Fidelity Text-to-Lip Generation; AAAI 2022 ; Official code

Parallel and High-Fidelity Text-to-Lip Generation This repository is the official PyTorch implementation of our AAAI-2022 paper, in which we propose P

Zhying 77 Dec 21, 2022
using yolox+deepsort for object-tracker

YOLOX_deepsort_tracker yolox+deepsort实现目标跟踪 最新的yolox尝尝鲜~~(yolox正处在频繁更新阶段,因此直接链接yolox仓库作为子模块) Install Clone the repository recursively: git clone --rec

245 Dec 26, 2022
unet for image segmentation

Implementation of deep learning framework -- Unet, using Keras The architecture was inspired by U-Net: Convolutional Networks for Biomedical Image Seg

zhixuhao 4.1k Dec 31, 2022
Linear algebra python - Number of operations and problems in Linear Algebra and Numerical Linear Algebra

Linear algebra in python Number of operations and problems in Linear Algebra and

Alireza 5 Oct 09, 2022
Scheme for training and applying a label propagation framework

Factorisation-based Image Labelling Overview This is a scheme for training and applying the factorisation-based image labelling (FIL) framework. Some

Wellcome Centre for Human Neuroimaging 2 Dec 17, 2021
Convolutional Neural Network to detect deforestation in the Amazon Rainforest

Convolutional Neural Network to detect deforestation in the Amazon Rainforest This project is part of my final work as an Aerospace Engineering studen

5 Feb 17, 2022
Repo for code associated with Modeling the Mitral Valve.

Project Title Mitral Valve Getting Started Repo for code associated with Modeling the Mitral Valve. See https://arxiv.org/abs/1902.00018 for preprint,

Alex Kaiser 1 May 17, 2022
Integrated physics-based and ligand-based modeling.

ComBind ComBind integrates data-driven modeling and physics-based docking for improved binding pose prediction and binding affinity prediction. Given

Dror Lab 44 Oct 26, 2022
3.8% and 18.3% on CIFAR-10 and CIFAR-100

Wide Residual Networks This code was used for experiments with Wide Residual Networks (BMVC 2016) http://arxiv.org/abs/1605.07146 by Sergey Zagoruyko

Sergey Zagoruyko 1.2k Dec 29, 2022
Nightmare-Writeup - Writeup for the Nightmare CTF Challenge from 2022 DiceCTF

Nightmare: One Byte to ROP // Alternate Solution TLDR: One byte write, no leak.

1 Feb 17, 2022
ToFFi - Toolbox for Frequency-based Fingerprinting of Brain Signals

ToFFi Toolbox This repository contains "before peer review" version of the software related to the preprint of the publication ToFFi - Toolbox for Fre

4 Aug 31, 2022
Pytorch implementation of

EfficientTTS Unofficial Pytorch implementation of "EfficientTTS: An Efficient and High-Quality Text-to-Speech Architecture"(arXiv). Disclaimer: Somebo

Liu Songxiang 109 Nov 16, 2022
Data-driven reduced order modeling for nonlinear dynamical systems

SSMLearn Data-driven Reduced Order Models for Nonlinear Dynamical Systems This package perform data-driven identification of reduced order model based

Haller Group, Nonlinear Dynamics 27 Dec 13, 2022
3 Apr 20, 2022