Contrastive Learning of Structured World Models

This repository contains the official PyTorch implementation of:

Contrastive Learning of Structured World Models.
Thomas Kipf, Elise van der Pol, Max Welling.
http://arxiv.org/abs/1911.12247

Abstract: A structured understanding of our world in terms of objects, relations, and hierarchies is an important component of human cognition. Learning such a structured world model from raw sensory data remains a challenge. As a step towards this goal, we introduce Contrastively-trained Structured World Models (C-SWMs). C-SWMs utilize a contrastive approach for representation learning in environments with compositional structure. We structure each state embedding as a set of object representations and their relations, modeled by a graph neural network. This allows objects to be discovered from raw pixel observations without direct supervision as part of the learning process. We evaluate C-SWMs on compositional environments involving multiple interacting objects that can be manipulated independently by an agent, simple Atari games, and a multi-object physics simulation. Our experiments demonstrate that C-SWMs can overcome limitations of models based on pixel reconstruction and outperform typical representatives of this model class in highly structured environments, while learning interpretable object-based representations.
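
The abstract describes the contrastive objective only at a high level. As a rough illustration, a margin-based contrastive loss over object-factored embeddings could be sketched as below. This is plain Python rather than PyTorch, and it is not taken from train.py: the function names, the additive update `delta`, the squared Euclidean distance, and the default margin are all assumptions made for the example. The paper gives the exact formulation.

```python
# Illustrative sketch only: a margin-based contrastive loss for a latent
# transition model. All names are hypothetical, not the repository's API.


def squared_distance(a, b):
    """Squared Euclidean distance between two object-factored states.

    Each state is a list of K object embeddings (lists of floats); the
    per-object distances are averaged over the K object slots.
    """
    per_object = [
        sum((x - y) ** 2 for x, y in zip(obj_a, obj_b))
        for obj_a, obj_b in zip(a, b)
    ]
    return sum(per_object) / len(per_object)


def contrastive_loss(z_t, delta, z_next, z_neg, margin=1.0):
    """Pull the predicted next state towards the true next state and push
    a negative (e.g. randomly sampled) state at least `margin` away."""
    # Assumed transition form: the model predicts an additive update per object.
    predicted = [
        [x + d for x, d in zip(obj, obj_delta)]
        for obj, obj_delta in zip(z_t, delta)
    ]
    positive = squared_distance(predicted, z_next)
    negative = squared_distance(z_neg, z_next)
    # Hinge term: no penalty once the negative is further away than the margin.
    return positive + max(0.0, margin - negative)
```

With two objects and two-dimensional embeddings, a perfect prediction paired with a distant negative gives zero loss, while using the true next state as the "negative" is penalized by the full margin.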

Requirements

  • Python 3.6 or 3.7
  • PyTorch version 1.2
  • OpenAI Gym version 0.12.0: pip install gym==0.12.0
  • OpenAI Atari_py version 0.1.4: pip install atari-py==0.1.4
  • Scikit-image version 0.15.0: pip install scikit-image==0.15.0
  • Matplotlib version 3.0.2: pip install matplotlib==3.0.2
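
Because the versions above are pinned, it can help to check an environment before generating data. The following is a minimal sketch, not part of the repository: it assumes the usual import names for each package (`torch`, `gym`, `atari_py`, `skimage`, `matplotlib`) and assumes each module exposes a `__version__` attribute, which may not hold for every package.

```python
import importlib

# Pins copied from the requirements list. The import names on the left are
# assumptions of this sketch, not something the README states.
PINNED = {
    "torch": "1.2",
    "gym": "0.12.0",
    "atari_py": "0.1.4",
    "skimage": "0.15.0",
    "matplotlib": "3.0.2",
}


def matches_pin(installed, pinned):
    """True if `installed` starts with the dotted components of `pinned`,
    so "1.2.0" satisfies the pin "1.2" but "1.20.0" does not."""
    want = pinned.split(".")
    # Drop a local build suffix such as "+cu92" before comparing.
    have = installed.split("+")[0].split(".")
    return have[:len(want)] == want


def find_mismatches(pins=PINNED):
    """Return {module: (pinned, found)} for missing or mismatched packages."""
    problems = {}
    for module_name, pinned in pins.items():
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            problems[module_name] = (pinned, None)
            continue
        found = getattr(module, "__version__", None)
        if found is None or not matches_pin(str(found), pinned):
            problems[module_name] = (pinned, found)
    return problems
```

Calling `find_mismatches()` returns an empty dictionary when every pinned package is importable at the expected version; a `None` in the "found" position means the package is missing or does not report a version.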

Generate datasets

2D Shapes:

python data_gen/env.py --env_id ShapesTrain-v0 --fname data/shapes_train.h5 --num_episodes 1000 --seed 1
python data_gen/env.py --env_id ShapesEval-v0 --fname data/shapes_eval.h5 --num_episodes 10000 --seed 2

3D Cubes:

python data_gen/env.py --env_id CubesTrain-v0 --fname data/cubes_train.h5 --num_episodes 1000 --seed 3
python data_gen/env.py --env_id CubesEval-v0 --fname data/cubes_eval.h5 --num_episodes 10000 --seed 4

Atari Pong:

python data_gen/env.py --env_id PongDeterministic-v4 --fname data/pong_train.h5 --num_episodes 1000 --atari --seed 1
python data_gen/env.py --env_id PongDeterministic-v4 --fname data/pong_eval.h5 --num_episodes 100 --atari --seed 2

Space Invaders:

python data_gen/env.py --env_id SpaceInvadersDeterministic-v4 --fname data/spaceinvaders_train.h5 --num_episodes 1000 --atari --seed 1
python data_gen/env.py --env_id SpaceInvadersDeterministic-v4 --fname data/spaceinvaders_eval.h5 --num_episodes 100 --atari --seed 2

3-Body Gravitational Physics:

python data_gen/physics.py --num-episodes 5000 --fname data/balls_train.h5 --seed 1
python data_gen/physics.py --num-episodes 1000 --fname data/balls_eval.h5 --eval --seed 2
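
The commands above can also be driven from a small script. The sketch below is a hypothetical helper, not part of the repository; it only reassembles the command lines already listed. Note that data_gen/env.py takes underscore-style flags (`--env_id`, `--num_episodes`) while data_gen/physics.py takes hyphen-style flags (`--num-episodes`), so the two builders are kept separate.

```python
import subprocess


def env_command(env_id, fname, num_episodes, seed, atari=False):
    """Argument list for data_gen/env.py (underscore-style flags)."""
    cmd = ["python", "data_gen/env.py", "--env_id", env_id,
           "--fname", fname, "--num_episodes", str(num_episodes)]
    if atari:
        cmd.append("--atari")
    return cmd + ["--seed", str(seed)]


def physics_command(fname, num_episodes, seed, eval_split=False):
    """Argument list for data_gen/physics.py (hyphen-style flags)."""
    cmd = ["python", "data_gen/physics.py",
           "--num-episodes", str(num_episodes), "--fname", fname]
    if eval_split:
        cmd.append("--eval")
    return cmd + ["--seed", str(seed)]


def all_commands():
    """The ten dataset-generation commands listed in this section."""
    pong = "PongDeterministic-v4"
    invaders = "SpaceInvadersDeterministic-v4"
    return [
        env_command("ShapesTrain-v0", "data/shapes_train.h5", 1000, 1),
        env_command("ShapesEval-v0", "data/shapes_eval.h5", 10000, 2),
        env_command("CubesTrain-v0", "data/cubes_train.h5", 1000, 3),
        env_command("CubesEval-v0", "data/cubes_eval.h5", 10000, 4),
        env_command(pong, "data/pong_train.h5", 1000, 1, atari=True),
        env_command(pong, "data/pong_eval.h5", 100, 2, atari=True),
        env_command(invaders, "data/spaceinvaders_train.h5", 1000, 1, atari=True),
        env_command(invaders, "data/spaceinvaders_eval.h5", 100, 2, atari=True),
        physics_command("data/balls_train.h5", 5000, 1),
        physics_command("data/balls_eval.h5", 1000, 2, eval_split=True),
    ]


def run_all(dry_run=True):
    """Print each command; execute it as well when dry_run is False."""
    for cmd in all_commands():
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)
```

`run_all()` only prints the commands by default; `run_all(dry_run=False)` executes them in order from the repository root and stops at the first failure.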

Run model training and evaluation

2D Shapes:

python train.py --dataset data/shapes_train.h5 --encoder small --name shapes
python eval.py --dataset data/shapes_eval.h5 --save-folder checkpoints/shapes --num-steps 1

3D Cubes:

python train.py --dataset data/cubes_train.h5 --encoder large --name cubes
python eval.py --dataset data/cubes_eval.h5 --save-folder checkpoints/cubes --num-steps 1

Atari Pong:

python train.py --dataset data/pong_train.h5 --encoder medium --embedding-dim 4 --action-dim 6 --num-objects 3 --copy-action --epochs 200 --name pong
python eval.py --dataset data/pong_eval.h5 --save-folder checkpoints/pong --num-steps 1

Space Invaders:

python train.py --dataset data/spaceinvaders_train.h5 --encoder medium --embedding-dim 4 --action-dim 6 --num-objects 3 --copy-action --epochs 200 --name spaceinvaders
python eval.py --dataset data/spaceinvaders_eval.h5 --save-folder checkpoints/spaceinvaders --num-steps 1

3-Body Gravitational Physics:

python train.py --dataset data/balls_train.h5 --encoder medium --embedding-dim 4 --num-objects 3 --ignore-action --name balls
python eval.py --dataset data/balls_eval.h5 --save-folder checkpoints/balls --num-steps 1
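
Each experiment above follows the same pattern: train.py is given `--name X`, and eval.py then reads from `checkpoints/X`. The sketch below captures that pairing in a small table. It is a hypothetical helper, not part of the repository, and it assumes the `data/X_train.h5` and `data/X_eval.h5` file names used in this README. Passing a `num_steps` other than 1 is likewise an assumption, since only `--num-steps 1` is shown here.

```python
# Per-experiment training flags, transcribed from the commands above.
ATARI_FLAGS = ["--encoder", "medium", "--embedding-dim", "4",
               "--action-dim", "6", "--num-objects", "3",
               "--copy-action", "--epochs", "200"]

EXPERIMENTS = {
    "shapes": ["--encoder", "small"],
    "cubes": ["--encoder", "large"],
    "pong": ATARI_FLAGS,
    "spaceinvaders": ATARI_FLAGS,
    "balls": ["--encoder", "medium", "--embedding-dim", "4",
              "--num-objects", "3", "--ignore-action"],
}


def train_command(name):
    """Argument list for training the experiment called `name`."""
    return (["python", "train.py", "--dataset", "data/{}_train.h5".format(name)]
            + EXPERIMENTS[name] + ["--name", name])


def eval_command(name, num_steps=1):
    """Argument list for evaluating the checkpoint saved under `name`."""
    return ["python", "eval.py",
            "--dataset", "data/{}_eval.h5".format(name),
            "--save-folder", "checkpoints/{}".format(name),
            "--num-steps", str(num_steps)]


def command_pairs():
    """(train, eval) command strings for every experiment in the table."""
    return [(" ".join(train_command(name)), " ".join(eval_command(name)))
            for name in EXPERIMENTS]
```

Printing the result of `command_pairs()` reproduces the command lines in this section, which makes it easy to spot a mismatch between a run name and the checkpoint folder passed to eval.py.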

Cite

If you make use of this code in your own work, please cite our paper:

@article{kipf2019contrastive,
  title={Contrastive Learning of Structured World Models}, 
  author={Kipf, Thomas and van der Pol, Elise and Welling, Max}, 
  journal={arXiv preprint arXiv:1911.12247}, 
  year={2019} 
}