Invariant Causal Prediction for Block MDPs

Overview

MISA

Abstract

Generalization across environments is critical to the successful application of reinforcement learning algorithms to real-world challenges. In this paper, we consider the problem of learning abstractions that generalize in block MDPs, families of environments with a shared latent state space and dynamics structure over that latent space, but varying observations. We leverage tools from causal inference to propose a method of invariant prediction to learn model-irrelevance state abstractions (MISA) that generalize to novel observations in the multi-environment setting. We prove that for certain classes of environments, this approach outputs with high probability a state abstraction corresponding to the causal feature set with respect to the return. We further provide more general bounds on model error and generalization error in the multi-environment setting, in the process showing a connection between causal variable selection and the state abstraction framework for MDPs. We give empirical evidence that our methods work in both linear and nonlinear settings, attaining improved generalization over single- and multi-task baselines.
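To make the block MDP setting concrete, here is a minimal toy sketch. It is not taken from this repository: the class `BlockMDPEnv`, the ring-shaped latent space, and the permutation-based observation function are all assumptions chosen for illustration. Two environments share the same latent dynamics and rewards but emit different observations, and each observation maps back to exactly one latent state.

```python
class BlockMDPEnv:
    """Toy block MDP (hypothetical example, not the repository's code).

    Latent states 0..n_states-1 sit on a ring. Every environment shares the
    same latent dynamics and reward, but relabels states with its own
    permutation to produce observations.
    """

    def __init__(self, n_states, perm):
        assert sorted(perm) == list(range(n_states)), "perm must be a permutation"
        self.n_states = n_states
        self.perm = list(perm)  # environment-specific observation map
        self.inverse = {obs: s for s, obs in enumerate(self.perm)}
        self.state = 0

    def observe(self):
        # Each observation is generated by exactly one latent state.
        return self.perm[self.state]

    def decode(self, obs):
        # Recover the latent state behind an observation.
        return self.inverse[obs]

    def step(self, action):
        # Shared latent dynamics: action is -1 (left) or +1 (right).
        self.state = (self.state + action) % self.n_states
        reward = 1.0 if self.state == 0 else 0.0
        return self.observe(), reward


def rollout(env, actions):
    """Apply a fixed action sequence and collect (observation, reward) pairs."""
    return [env.step(a) for a in actions]
```

Running the same action sequence in two such environments yields different observation sequences but identical rewards and identical decoded latent states. An abstraction that generalizes across environments has to capture exactly that shared latent structure.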

Citation

@inproceedings{zhang2020invariant,
    title={Invariant Causal Prediction for Block MDPs},
    author={Amy Zhang and Clare Lyle and Shagun Sodhani and Angelos Filos and Marta Kwiatkowska and Joelle Pineau and Yarin Gal and Doina Precup},
    year={2020},
    booktitle={International Conference on Machine Learning (ICML)},
}

Experiments

The three sets of experiments on model learning, imitation learning, and reinforcement learning can be found in their respective folders. To install requirements, create a new conda environment and run

pip install -r requirements.txt

Model learning has two sets of experiments, linear MISA and nonlinear MISA. The code is in model_learning, so first cd model_learning.

The main experiment with linear MISA is in the Jupyter notebook

ICPAbstractMDP.ipynb

The main experiment with nonlinear MISA can be run with

python main.py

To run the imitation learning experiments, first cd imitation_learning. Then install the baselines by running cd baselines && pip install tensorflow==1.14 && pip install -e . The main experiments can then be run from the imitation_learning directory with:

python train_expert.py --save_model --save_model_path models # Training the expert model

# Let's say the model was trained for 150K steps.

mkdir -p buffers/train/0 buffers/train/1 buffers/eval/0 # Directory to hold the buffer data

python collect_data_using_expert_policy.py --load_model_path models_150000 --save_buffer --save_buffer_path buffers  # Collecting the trajectories using the expert model

python train.py --use_single_encoder_decoder --num_train_envs 1 --num_eval_envs 1 --load_buffer_path buffers # Baseline One Env

python train.py --use_single_encoder_decoder --num_train_envs 2 --num_eval_envs 1 --load_buffer_path buffers # Baseline One Decoder 

python train.py --use_discriminator --num_train_envs 2 --num_eval_envs 1 --load_buffer_path buffers # Proposed Approach

python train.py --use_irm_loss --num_train_envs 2 --num_eval_envs 1 --load_buffer_path buffers # IRM
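The mkdir step above hardcodes two training environments and one evaluation environment. If you change --num_train_envs or --num_eval_envs, the buffer directories presumably need to match. The helper below is a hypothetical convenience, not part of this repository. It assumes, based only on the mkdir command, that each environment's buffer lives in buffers/train/<index> or buffers/eval/<index>.

```python
from pathlib import Path


def make_buffer_dirs(root, num_train_envs, num_eval_envs):
    """Create <root>/train/<i> and <root>/eval/<j>, mirroring the mkdir -p step.

    Hypothetical helper: the one-directory-per-environment-index layout is
    inferred from the README command, not from the training code.
    """
    created = []
    for split, count in (("train", num_train_envs), ("eval", num_eval_envs)):
        for idx in range(count):
            path = Path(root) / split / str(idx)
            # parents/exist_ok give the same behavior as `mkdir -p`.
            path.mkdir(parents=True, exist_ok=True)
            created.append(path)
    return created
```

Calling make_buffer_dirs("buffers", num_train_envs=2, num_eval_envs=1) creates the same three directories as the mkdir command above.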

In reinforcement learning, the main experiment can be run in the reinforcement_learning directory with

./run_local.sh

LICENSE

Attribution-NonCommercial 4.0 International
