Conservative Q Learning for Offline Reinforcement Reinforcement Learning in JAX

Related tags

Deep Learningcql-jax
Overview

CQL-JAX

This repository implements Conservative Q Learning for Offline Reinforcement Reinforcement Learning in JAX (FLAX). Implementation is built on top of the SAC base of JAX-RL.

Usage

Install Dependencies-

pip install -r requirements.txt
pip install "jax[cuda111]<=0.21.1" -f https://storage.googleapis.com/jax-releases/jax_releases.html

Run CQL-

python train_offline.py --env_name=hopper-expert-v0 --min_q_weight=5

Please use the following values of min_q_weight on MuJoCo tasks to reproduce CQL results from IQL paper-

Domain medium medium-replay medium-expert
walker 10 1 10
hopper 5 5 1
cheetah 90 80 100

For antmaze tasks min_q_weight=10 is found to work best.

In case of Out-Of Memory errors in JAX, try running with the following env variables-

XLA_PYTHON_CLIENT_MEM_FRACTION=0.80 python ...
XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1 python ...

Performance & Runtime

Returns are more or less same as the torch implementation and comparable to IQL-

Task CQL(PyTorch) CQL(JAX) IQL
hopper-medium-v2 58.5 74.6 66.3
hopper-medium-replay-v2 95.0 92.1 94.7
hopper-medium-expert-v2 105.4 83.2 91.5
antmaze-umaze-v0 74.0 69.5 87.5
antmaze-umaze-diverse-v0 84.0 78.7 62.2
antmaze-medium-play-v0 61.2 14.2 71.2
antmaze-medium-diverse-v0 53.7 10.7 70.2
antmaze-large-play-v0 15.8 0.0 39.6
antmaze-large-diverse-v0 14.9 0.0 47.5

Wall-clock time averages to ~50 mins, improving over IQL paper's 80 min CQL and closing the gap with IQL's 20 min.

Task CQL(JAX) IQL
hopper-medium-v2 52 27
hopper-medium-replay-v2 54 30
hopper-medium-expert-v2 57 29

Time efficiency over the original torch implementation is more than 4 times.

For more offline RL algorithm implementations, check out the JAX-RL, IQL and rlkit repositories.

Citation

In case you use CQL-JAX for your research, please cite the following-

@misc{cqljax,
  author = {Suri, Karush},
  title = {{Conservative Q Learning in JAX.}},
  url = {https://github.com/karush17/cql-jax},
  year = {2021}
}

References

Owner
Karush Suri
Deep Learning Researcher at Huawei Noah's Ark Lab, Toronto.
Karush Suri
Deep learning models for classification of 15 common weeds in the southern U.S. cotton production systems.

CottonWeeds Deep learning models for classification of 15 common weeds in the southern U.S. cotton production systems. requirements pytorch torchsumma

Dong Chen 8 Jun 07, 2022
PartImageNet is a large, high-quality dataset with part segmentation annotations

PartImageNet: A Large, High-Quality Dataset of Parts We will release our dataset and scripts soon after cleaning and approval. Introduction PartImageN

Ju He 77 Nov 30, 2022
Python Wrapper for Embree

pyembree Python Wrapper for Embree Installation You can install pyembree (and embree) via the conda-forge package. $ conda install -c conda-forge pyem

Anthony Scopatz 67 Dec 24, 2022
This porject is intented to build the most accurate model for predicting the porbability of loan default

Estimating-Loan-Default-Probability IBA ML2 Mid-project / Kaggle Competition This porject is intented to build the most accurate model for predicting

Adil Gahramanov 1 Jan 24, 2022
Official implementation of "Variable-Rate Deep Image Compression through Spatially-Adaptive Feature Transform", ICCV 2021

Variable-Rate Deep Image Compression through Spatially-Adaptive Feature Transform This repository is the implementation of "Variable-Rate Deep Image C

Myungseo Song 47 Dec 13, 2022
dualPC.R contains the R code for the main functions.

dualPC.R contains the R code for the main functions. dualPC_sim.R contains an example run with the different PC versions; it calls dualPC_algs.R whic

3 May 30, 2022
Dynamic Multi-scale Filters for Semantic Segmentation (DMNet ICCV'2019)

Dynamic Multi-scale Filters for Semantic Segmentation (DMNet ICCV'2019) Introduction Official implementation of Dynamic Multi-scale Filters for Semant

23 Oct 21, 2022
This is an official implementation for "DeciWatch: A Simple Baseline for 10x Efficient 2D and 3D Pose Estimation"

DeciWatch: A Simple Baseline for 10× Efficient 2D and 3D Pose Estimation This repo is the official implementation of "DeciWatch: A Simple Baseline for

117 Dec 24, 2022
This is an example of object detection on Micro bacterium tuberculosis using Mask-RCNN

Mask-RCNN on Mycobacterium tuberculosis This is an example of object detection on Mycobacterium Tuberculosis using Mask RCNN. Implement of Mask R-CNN

Jun-En Ding 1 Sep 16, 2021
Learning to Predict Gradients for Semi-Supervised Continual Learning

Learning to Predict Gradients for Semi-Supervised Continual Learning Code for project: "Learning to Predict Gradients for Semi-Supervised Continual Le

Yan Luo 2 Mar 05, 2022
Ascend your Jupyter Notebook usage

Jupyter Ascending Sync Jupyter Notebooks from any editor About Jupyter Ascending lets you edit Jupyter notebooks from your favorite editor, then insta

Untitled AI 254 Jan 08, 2023
[NeurIPS'21] Projected GANs Converge Faster

[Project] [PDF] [Supplementary] [Talk] This repository contains the code for our NeurIPS 2021 paper "Projected GANs Converge Faster" by Axel Sauer, Ka

798 Jan 04, 2023
BrainGNN - A deep learning model for data-driven discovery of functional connectivity

A deep learning model for data-driven discovery of functional connectivity https://doi.org/10.3390/a14030075 Usman Mahmood, Zengin Fu, Vince D. Calhou

Usman Mahmood 3 Aug 28, 2022
Solutions of Reinforcement Learning 2nd Edition

Solutions of Reinforcement Learning, An Introduction

YIFAN WANG 1.4k Dec 30, 2022
Weight initialization schemes for PyTorch nn.Modules

nninit Weight initialization schemes for PyTorch nn.Modules. This is a port of the popular nninit for Torch7 by @kaixhin. ##Update This repo has been

Alykhan Tejani 69 Jan 26, 2021
Datasets and pretrained Models for StyleGAN3 ...

Datasets and pretrained Models for StyleGAN3 ... Dear arfiticial friend, this is a collection of artistic datasets and models that we have put togethe

lucid layers 34 Oct 06, 2022
Joint deep network for feature line detection and description

SOLD² - Self-supervised Occlusion-aware Line Description and Detection This repository contains the implementation of the paper: SOLD² : Self-supervis

Computer Vision and Geometry Lab 427 Dec 27, 2022
Cleaned up code for DSTC 10: SIMMC 2.0 track: subtask 2: multimodal coreference resolution

UNITER-Based Situated Coreference Resolution with Rich Multimodal Input: arXiv MMCoref_cleaned Code for the MMCoref task of the SIMMC 2.0 dataset. Pre

Yichen (William) Huang 2 Dec 05, 2022
Kaggle Feedback Prize - Evaluating Student Writing 15th solution

Kaggle Feedback Prize - Evaluating Student Writing 15th solution First of all, I would like to thank the excellent notebooks and discussions from http

Lingyuan Zhang 6 Mar 24, 2022
L-Verse: Bidirectional Generation Between Image and Text

Far beyond learning long-range interactions of natural language, transformers are becoming the de-facto standard for many vision tasks with their power and scalabilty

Kim, Taehoon 102 Dec 21, 2022