Causal Influence Detection for Improving Efficiency in Reinforcement Learning

This repository contains the code release for the paper "Causal Influence Detection for Improving Efficiency in Reinforcement Learning", published at NeurIPS 2021.

This work was done by Maximilian Seitzer, Bernhard Schölkopf and Georg Martius at the Autonomous Learning Group, Max Planck Institute for Intelligent Systems.

If you make use of our work, please use the citation information below.

Abstract

Many reinforcement learning (RL) environments consist of independent entities that interact sparsely. In such environments, RL agents have only limited influence over other entities in any particular situation. Our idea in this work is that learning can be efficiently guided by knowing when and what the agent can influence with its actions. To achieve this, we introduce a measure of situation-dependent causal influence based on conditional mutual information and show that it can reliably detect states of influence. We then propose several ways to integrate this measure into RL algorithms to improve exploration and off-policy learning. All modified algorithms show strong increases in data efficiency on robotic manipulation tasks.
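For reference, a situation-dependent influence measure of this kind can be written in standard notation as the conditional mutual information between the action A and the next state S'_j of entity j, given the current state S = s. The symbols below are our own choice and may differ from the paper's; the second equality is the standard identity that expresses mutual information as an expected KL divergence:

```latex
C^j(s) := I\left(S'_j ; A \mid S = s\right)
  = \mathbb{E}_{a \sim \pi(\cdot \mid s)}\left[ D_{\mathrm{KL}}\left( p(s'_j \mid s, a) \,\middle\|\, p(s'_j \mid s) \right) \right],
\qquad
p(s'_j \mid s) = \mathbb{E}_{a \sim \pi(\cdot \mid s)}\left[ p(s'_j \mid s, a) \right]
```

C^j(s) is zero exactly when, in state s, the distribution of S'_j does not depend on the action (for actions the policy can take), which is what makes it usable as a detector of states of influence.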

Setup

Use make_conda_env.sh to create a Conda environment with minimal dependencies:

./make_conda_env.sh minimal cid_in_rl

Alternatively, recreate the exact environment used to obtain the paper's results (this installs more dependencies than necessary):

conda env create -f orig_environment.yml

Activate the environment with conda activate cid_in_rl.

Experiments

Causal Influence Detection

To reproduce the causal influence detection experiment, download the datasets used in the paper here and extract them into the folder data/. The simplest way to run all experiments is to use the included Makefile (this will take a long time):

make -C experiments/1-influence

The results will be in the folder ./data/experiments/1-influence/.

You can also train a single model, for example

python -m cid.influence_estimation.train_model \
        --log-dir logs/eval_fetchpickandplace \
        --no-logging-subdir --seed 0 \
        --memory-path data/fetchpickandplace/memory_5k_her_agent_v2.npy \
        --val-memory-path data/fetchpickandplace/val_memory_2kof5k_her_agent_v2.npy \
        experiments/1-influence/pickandplace_model_gaussian.gin

which will train a model on FetchPickAndPlace and write the results to logs/eval_fetchpickandplace.
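The config name pickandplace_model_gaussian.gin suggests a transition model that predicts a Gaussian distribution over the next state. As a minimal sketch, assuming a diagonal Gaussian parameterized by a predicted mean and log-variance, such a model would typically be trained by minimizing the per-sample negative log-likelihood below. The helper name and its arguments are hypothetical and not part of the repository's API.

```python
import numpy as np


def gaussian_nll(target, mean, log_var):
    """Negative log-likelihood of `target` under a diagonal Gaussian.

    Hypothetical helper for illustration: `mean` and `log_var` stand in for a
    model's predicted parameters, each with shape (batch, dim).
    Returns one value per batch element.
    """
    target = np.asarray(target, dtype=float)
    mean = np.asarray(mean, dtype=float)
    log_var = np.asarray(log_var, dtype=float)
    var = np.exp(log_var)
    # Per-dimension negative log-density of a univariate Gaussian.
    per_dim = 0.5 * (np.log(2 * np.pi) + log_var + (target - mean) ** 2 / var)
    # Independent dimensions: sum over the last axis.
    return per_dim.sum(axis=-1)


# One sample, one dimension, standard normal, evaluated at its mean.
print(gaussian_nll([[0.0]], [[0.0]], [[0.0]]))  # -> [0.91893853]
```

Predicting the log-variance rather than the variance keeps the variance positive without an explicit constraint.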

To evaluate the CAI scores computed with the trained model on the validation set, run

python experiments/1-influence/pickandplace_cmi.py \
    --output-path logs/eval_fetchpickandplace \
    --model-path logs/eval_fetchpickandplace \
    --settings-path logs/eval_fetchpickandplace/eval_settings.gin \
    --memory-path data/fetchpickandplace/val_memory_2kof5k_her_agent_v2.npy \
    --variants var_prod_approx
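The CAI score is a conditional mutual information, which has no closed form even when the model predicts a Gaussian for each action. The sketch below shows one simple Monte Carlo estimate under our own assumptions: a diagonal Gaussian model, K actions assumed to be drawn from the policy, the marginal p(s' | s) approximated by the uniform mixture over those K actions, and the KL terms estimated by sampling. The function names and the model interface are hypothetical. This is not the repository's implementation and is not meant to match any particular --variants option.

```python
import numpy as np


def diag_gaussian_logpdf(x, mean, var):
    """Log-density of x under N(mean, diag(var)); broadcasts over leading axes."""
    return -0.5 * np.sum(np.log(2 * np.pi * var) + (x - mean) ** 2 / var, axis=-1)


def estimate_cai(model, state, actions, n_samples=64, rng=None):
    """Monte Carlo estimate of I(S'; A | S = state).

    model(state, action) -> (mean, var) of a diagonal Gaussian over the next
    state (hypothetical interface). `actions` has shape (K, action_dim) and is
    assumed to be sampled from the policy in `state`.
    """
    rng = np.random.default_rng(rng)
    params = [model(state, a) for a in actions]
    means = np.stack([m for m, _ in params])  # (K, D)
    variances = np.stack([v for _, v in params])  # (K, D)
    k, dim = means.shape
    total = 0.0
    for i in range(k):
        # Draw samples from p(s' | s, a_i).
        noise = rng.standard_normal((n_samples, dim))
        x = means[i] + np.sqrt(variances[i]) * noise  # (N, D)
        log_cond = diag_gaussian_logpdf(x, means[i], variances[i])  # (N,)
        # Mixture approximation of the marginal p(s' | s) over the K actions.
        log_comp = diag_gaussian_logpdf(x[:, None, :], means[None], variances[None])  # (N, K)
        log_marg = np.logaddexp.reduce(log_comp, axis=1) - np.log(k)  # (N,)
        # Sample estimate of KL(p(s' | s, a_i) || p(s' | s)).
        total += np.mean(log_cond - log_marg)
    return total / k


actions = np.array([[-10.0], [0.0], [10.0]])
state = np.zeros(1)
# Next state ignores the action: the estimate is (numerically) zero.
print(estimate_cai(lambda s, a: (s, np.ones(1)), state, actions, rng=0))
# Next state is set by the action: the estimate is close to log(3).
print(estimate_cai(lambda s, a: (a, np.ones(1)), state, actions, rng=0))
```

Because the mixture density is at least 1/K times each component, this estimate can never exceed log K, so K has to be large enough for strong influence to be distinguishable.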

Reinforcement Learning

The RL experiments can be reproduced using the settings in experiments/2-prioritization, experiments/3-exploration and experiments/4-other.

To do so, run

python -m cid.train 
   

   

By default, the output will be in the folder ./logs.
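The prioritization experiments replay stored transitions non-uniformly. As a minimal sketch of the general idea, assuming transitions are sampled with probability proportional to a per-transition score such as the CAI score (plus a small constant so that no transition is never replayed), drawing a batch of replay indices could look like this. The helper and its parameters are hypothetical; cid/memory/mbp contains the repository's actual replay buffer.

```python
import numpy as np


def sample_prioritized(scores, batch_size, eps=1e-6, rng=None):
    """Sample replay indices with probability proportional to score + eps.

    Hypothetical helper: `scores` holds one non-negative score per stored
    transition; sampling is with replacement.
    """
    rng = np.random.default_rng(rng)
    priorities = np.asarray(scores, dtype=float) + eps
    probs = priorities / priorities.sum()
    return rng.choice(len(probs), size=batch_size, replace=True, p=probs)


# Only transition 2 has a high influence score, so it dominates the batch.
scores = np.array([0.0, 0.0, 5.0, 0.0])
batch = sample_prioritized(scores, batch_size=8, rng=0)
print(batch)
```

Prioritized Experience Replay additionally corrects the resulting sampling bias with importance-sampling weights; this sketch omits that step.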

Codebase Overview

  • cid/algorithms/ddpg_agent.py contains the DDPG agent
  • cid/envs contains new environments
    • cid/envs/one_d_slide.py implements the 1D-Slide dataset
    • cid/envs/robotics/pick_and_place_rot_table.py implements the RotatingTable environment
    • cid/envs/robotics/fetch_control_detection.py contains the code for deriving ground truth control labels for FetchPickAndPlace
  • cid/influence_estimation contains code for model training, evaluation and computing the causal influence score
    • cid/influence_estimation/train_model.py is the main model training script
    • cid/influence_estimation/eval_influence.py evaluates a trained model for its classification performance
    • cid/influence_estimation/transition_scorers contains code for computing the CAI score
  • cid/memory/ contains the replay buffers, which handle prioritization and exploration bonuses
    • cid/memory/mbp implements CAI (ours)
    • cid/memory/her implements Hindsight Experience Replay
    • cid/memory/ebp implements Energy-Based Hindsight Experience Prioritization
    • cid/memory/per implements Prioritized Experience Replay
  • cid/models contains PyTorch model implementations
    • cid/bnn.py contains the implementation of VIME
  • cid/play.py lets a trained RL agent run in an environment
  • cid/train.py is the main RL training script
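eval_influence.py evaluates how well a trained model classifies states of influence against ground-truth labels, such as the control labels derived by fetch_control_detection.py. As an illustrative sketch of one threshold-free metric for this, assuming binary labels and real-valued scores, the area under the ROC curve can be computed by comparing every positive-negative pair, which equals the normalized Mann-Whitney U statistic. The helper is hypothetical and not necessarily the metric or code the repository uses.

```python
import numpy as np


def roc_auc(scores, labels):
    """ROC AUC as the fraction of (positive, negative) pairs ranked correctly.

    Ties between a positive and a negative score count as half a pair.
    Uses O(P * N) memory, which is fine for moderately sized validation sets.
    """
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels, dtype=bool)
    pos, neg = scores[labels], scores[~labels]
    if len(pos) == 0 or len(neg) == 0:
        raise ValueError("need at least one positive and one negative label")
    # Compare every positive score with every negative score.
    greater = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return (greater + 0.5 * ties) / (len(pos) * len(neg))


# Scores that rank every influence state above every non-influence state.
print(roc_auc([0.9, 0.8, 0.3, 0.1], [1, 1, 0, 0]))  # -> 1.0
```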

Citation

Please use the following citation if you make use of our work:

@inproceedings{Seitzer2021CID,
  title = {Causal Influence Detection for Improving Efficiency in Reinforcement Learning},
  author = {Seitzer, Maximilian and Sch{\"o}lkopf, Bernhard and Martius, Georg},
  booktitle = {Advances in Neural Information Processing Systems (NeurIPS 2021)},
  month = dec,
  year = {2021},
  url = {https://arxiv.org/abs/2106.03443},
  month_numeric = {12}
}

License

This implementation is licensed under the MIT license.

The robotics environments were adapted from OpenAI Gym under MIT license. The VIME implementation was adapted from https://github.com/alec-tschantz/vime under MIT license.
