Causal Influence Detection for Improving Efficiency in Reinforcement Learning

Overview

Causal Influence Detection for Improving Efficiency in Reinforcement Learning

This repository contains the code release for the paper "Causal Influence Detection for Improving Efficiency in Reinforcement Learning", published at NeurIPS 2021.

This work was done by Maximilian Seitzer, Bernhard Schölkopf and Georg Martius at the Autonomous Learning Group, Max-Planck Institute for Intelligent Systems.

If you make use of our work, please use the citation information below.

Abstract

Many reinforcement learning (RL) environments consist of independent entities that interact sparsely. In such environments, RL agents have only limited influence over other entities in any particular situation. Our idea in this work is that learning can be efficiently guided by knowing when and what the agent can influence with its actions. To achieve this, we introduce a measure of situation-dependent causal influence based on conditional mutual information and show that it can reliably detect states of influence. We then propose several ways to integrate this measure into RL algorithms to improve exploration and off-policy learning. All modified algorithms show strong increases in data efficiency on robotic manipulation tasks.

Setup

Use make_conda_env.sh to create a Conda environment with minimal dependencies:

./make_conda_env.sh minimal cid_in_rl

or recreate the environment used to get the results (more dependencies than necessary):

conda env create -f orig_environment.yml

Activate the environment with conda activate cid_in_rl.

Experiments

Causal Influence Detection

To reproduce the causal influence detection experiment, you will need to download the used datasets here. Extract them into the folder data/. The most simple way to run all experiments is to use the included Makefile (this will take a long time):

make -C experiments/1-influence

The results will be in the folder ./data/experiments/1-influence/.

You can also train a single model, for example

python -m cid.influence_estimation.train_model \
        --log-dir logs/eval_fetchpickandplace 
        --no-logging-subdir --seed 0 \
        --memory-path data/fetchpickandplace/memory_5k_her_agent_v2.npy \
        --val-memory-path data/fetchpickandplace/val_memory_2kof5k_her_agent_v2.npy \
        experiments/1-influence/pickandplace_model_gaussian.gin

which will train a model on FetchPickPlace, and put the results in logs/eval_fetchpickandplace.

To evaluate the CAI score performance of the model on the validation set, use

python experiments/1-influence/pickandplace_cmi.py 
    --output-path logs/eval_fetchpickandplace 
    --model-path logs/eval_fetchpickandplace
    --settings-path logs/eval_fetchpickandplace/eval_settings.gin \
    --memory-path data/fetchpickandplace/val_memory_2kof5k_her_agent_v2.npy 
    --variants var_prod_approx

Reinforcement Learning

The RL experiments can be reproduced using the settings in experiments/2-prioritization, experiments/3-exploration, experiments/4-other.

To do so, run

python -m cid.train 
   

   

By default, the output will be in the folder ./logs.

Codebase Overview

  • cid/algorithms/ddpg_agent.py contains the DDPG agent
  • cid/envs contains new environments
    • cid/envs/one_d_slide.py implements the 1D-Slide dataset
    • cid/envs/robotics/pick_and_place_rot_table.py implements the RotatingTable environment
    • cid/envs/robotics/fetch_control_detection.py contains the code for deriving ground truth control labels for FetchPickAndPlace
  • cid/influence_estimation contains code for model training, evaluation and computing the causal influence score
    • cid/influence_estimation/train_model.py is the main model training script
    • cid/influence_estimation/eval_influence.py evaluates a trained model for its classification performance
    • cid/influence_estimation/transition_scorers contains code for computing the CAI score
  • cid/memory/ contains the replay buffers, which handle prioritization and exploration bonuses
    • cid/memory/mbp implements CAI (ours)
    • cid/memory/her implements Hindsight Experience Replay
    • cid/memory/ebp implements Energy-Based Hindsight Experience Prioritization
    • cid/memory/per implements Prioritized Experience Replay
  • cid/models contains Pytorch model implementations
    • cid/bnn.py contains the implementation of VIME
  • cid/play.py lets a trained RL agent run in an environment
  • cid/train.py is the main RL training script

Citation

Please use the following citation if you make use of our work:

@inproceedings{Seitzer2021CID,
  title = {Causal Influence Detection for Improving Efficiency in Reinforcement Learning},
  author = {Seitzer, Maximilian and Sch{\"o}lkopf, Bernhard and Martius, Georg},
  booktitle = {Advances in Neural Information Processing Systems (NeurIPS 2021)},
  month = dec,
  year = {2021},
  url = {https://arxiv.org/abs/2106.03443},
  month_numeric = {12}
}

License

This implementation is licensed under the MIT license.

The robotics environments were adapted from OpenAI Gym under MIT license. The VIME implementation was adapted from https://github.com/alec-tschantz/vime under MIT license.

Owner
Autonomous Learning Group
Autonomous Learning Group
Train the HRNet model on ImageNet

High-resolution networks (HRNets) for Image classification News [2021/01/20] Add some stronger ImageNet pretrained models, e.g., the HRNet_W48_C_ssld_

HRNet 866 Jan 04, 2023
Detection of PCBA defect

Detection_of_PCBA_defect Detection_of_PCBA_defect Use yolov5 to train. $pip install -r requirements.txt Detect.py will detect file(jpg,mp4...) in cu

6 Nov 28, 2022
Framework for training options with different attention mechanism and using them to solve downstream tasks.

Using Attention in HRL Framework for training options with different attention mechanism and using them to solve downstream tasks. Requirements GPU re

5 Nov 03, 2022
[CVPR 2021] Few-shot 3D Point Cloud Semantic Segmentation

Few-shot 3D Point Cloud Semantic Segmentation Created by Na Zhao from National University of Singapore Introduction This repository contains the PyTor

117 Dec 27, 2022
OstrichRL: A Musculoskeletal Ostrich Simulation to Study Bio-mechanical Locomotion.

OstrichRL This is the repository accompanying the paper OstrichRL: A Musculoskeletal Ostrich Simulation to Study Bio-mechanical Locomotion. It contain

Vittorio La Barbera 51 Nov 17, 2022
Deep Reinforced Attention Regression for Partial Sketch Based Image Retrieval.

DARP-SBIR Intro This repository contains the source code implementation for ICDM submission paper Deep Reinforced Attention Regression for Partial Ske

2 Jan 09, 2022
Official implementation of GraphMask as presented in our paper Interpreting Graph Neural Networks for NLP With Differentiable Edge Masking.

GraphMask This repository contains an implementation of GraphMask, the interpretability technique for graph neural networks presented in our ICLR 2021

Michael Schlichtkrull 29 Sep 02, 2022
Code and real data for the paper "Counterfactual Temporal Point Processes", available at arXiv.

counterfactual-tpp This is a repository containing code and real data for the paper Counterfactual Temporal Point Processes. Pre-requisites This code

Networks Learning 11 Dec 09, 2022
Code for MentorNet: Learning Data-Driven Curriculum for Very Deep Neural Networks

MentorNet: Learning Data-Driven Curriculum for Very Deep Neural Networks This is the code for the paper: MentorNet: Learning Data-Driven Curriculum fo

Google 302 Dec 23, 2022
Draw like Bob Ross using the power of Neural Networks (With PyTorch)!

Draw like Bob Ross using the power of Neural Networks! (+ Pytorch) Learning Process Visualization Getting started Install dependecies Requires python3

Kendrick Tan 116 Mar 07, 2022
Implementation of Diverse Semantic Image Synthesis via Probability Distribution Modeling

Diverse Semantic Image Synthesis via Probability Distribution Modeling (CVPR 2021) Paper Zhentao Tan, Menglei Chai, Dongdong Chen, Jing Liao, Qi Chu,

tzt 45 Nov 17, 2022
OptaPlanner wrappers for Python. Currently significantly slower than OptaPlanner in Java or Kotlin.

OptaPy is an AI constraint solver for Python to optimize the Vehicle Routing Problem, Employee Rostering, Maintenance Scheduling, Task Assignment, School Timetabling, Cloud Optimization, Conference S

OptaPy 211 Jan 02, 2023
Source code and data from the RecSys 2020 article "Carousel Personalization in Music Streaming Apps with Contextual Bandits" by W. Bendada, G. Salha and T. Bontempelli

Carousel Personalization in Music Streaming Apps with Contextual Bandits - RecSys 2020 This repository provides Python code and data to reproduce expe

Deezer 48 Jan 02, 2023
Implementation of the CVPR 2021 paper "Online Multiple Object Tracking with Cross-Task Synergy"

Online Multiple Object Tracking with Cross-Task Synergy This repository is the implementation of the CVPR 2021 paper "Online Multiple Object Tracking

54 Oct 15, 2022
Efficiently computes derivatives of numpy code.

Note: Autograd is still being maintained but is no longer actively developed. The main developers (Dougal Maclaurin, David Duvenaud, Matt Johnson, and

Formerly: Harvard Intelligent Probabilistic Systems Group -- Now at Princeton 6.1k Jan 08, 2023
Unofficial PyTorch implementation of Guided Dropout

Unofficial PyTorch implementation of Guided Dropout This is a simple implementation of Guided Dropout for research. We try to reproduce the algorithm

2 Jan 07, 2022
LSTM model trained on a small dataset of 3000 names written in PyTorch

LSTM model trained on a small dataset of 3000 names. Model generates names from model by selecting one out of top 3 letters suggested by model at a time until an EOS (End Of Sentence) character is no

Sahil Lamba 1 Dec 20, 2021
TinyML Cookbook, published by Packt

TinyML Cookbook This is the code repository for TinyML Cookbook, published by Packt. Author: Gian Marco Iodice Publisher: Packt About the book This bo

Packt 93 Dec 29, 2022
Alpha-IoU: A Family of Power Intersection over Union Losses for Bounding Box Regression

Alpha-IoU: A Family of Power Intersection over Union Losses for Bounding Box Regression YOLOv5 with alpha-IoU losses implemented in PyTorch. Example r

Jacobi(Jiabo He) 147 Dec 05, 2022
An implementation of chunked, compressed, N-dimensional arrays for Python.

Zarr Latest Release Package Status License Build Status Coverage Downloads Gitter Citation What is it? Zarr is a Python package providing an implement

Zarr Developers 1.1k Dec 30, 2022