Contrastive Learning of Structured World Models

This repository contains the official PyTorch implementation of:

Contrastive Learning of Structured World Models.
Thomas Kipf, Elise van der Pol, Max Welling.
http://arxiv.org/abs/1911.12247

Abstract: A structured understanding of our world in terms of objects, relations, and hierarchies is an important component of human cognition. Learning such a structured world model from raw sensory data remains a challenge. As a step towards this goal, we introduce Contrastively-trained Structured World Models (C-SWMs). C-SWMs utilize a contrastive approach for representation learning in environments with compositional structure. We structure each state embedding as a set of object representations and their relations, modeled by a graph neural network. This allows objects to be discovered from raw pixel observations without direct supervision as part of the learning process. We evaluate C-SWMs on compositional environments involving multiple interacting objects that can be manipulated independently by an agent, simple Atari games, and a multi-object physics simulation. Our experiments demonstrate that C-SWMs can overcome limitations of models based on pixel reconstruction and outperform typical representatives of this model class in highly structured environments, while learning interpretable object-based representations.
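The contrastive objective mentioned in the abstract can be illustrated for a single transition. What follows is a minimal pure-Python sketch, not the repository's PyTorch code. It assumes, based on our reading of the paper rather than this README, that the energy is the squared Euclidean distance averaged over K object slots, that the transition model (a GNN in the paper) outputs a per-object update `delta`, and that negatives are embeddings of other, randomly drawn states. The function names and the `margin` default are illustrative only.

```python
"""Sketch of a C-SWM-style contrastive hinge loss for one transition (pure Python)."""
from typing import List

# A state embedding: K object slots, each a D-dimensional vector.
Objects = List[List[float]]


def energy(a: Objects, b: Objects) -> float:
    """Squared Euclidean distance per object slot, averaged over the K slots."""
    assert len(a) == len(b), "both states need the same number of object slots"
    total = 0.0
    for za, zb in zip(a, b):
        total += sum((x - y) ** 2 for x, y in zip(za, zb))
    return total / len(a)


def contrastive_hinge_loss(z_t: Objects, delta: Objects, z_next: Objects,
                           z_neg: Objects, margin: float = 1.0) -> float:
    """Positive energy of the predicted transition plus a hinge on a negative sample."""
    # Apply the (hypothetical) transition model's per-object updates: z_t + delta.
    z_pred = [[x + d for x, d in zip(zk, dk)] for zk, dk in zip(z_t, delta)]
    # Pull the prediction towards the embedding of the observed next state.
    positive = energy(z_pred, z_next)
    # Push a randomly drawn state embedding at least `margin` away from the next state.
    negative = energy(z_neg, z_next)
    return positive + max(0.0, margin - negative)


if __name__ == "__main__":
    z_t = [[0.0, 0.0], [1.0, 1.0]]
    delta = [[1.0, 0.0], [0.0, 0.0]]   # object 0 moves, object 1 stays put
    z_next = [[1.0, 0.0], [1.0, 1.0]]
    z_neg = [[3.0, 0.0], [1.0, 1.0]]
    print(contrastive_hinge_loss(z_t, delta, z_next, z_neg))  # prints 0.0
```

The hinge term is what separates this from a plain prediction loss: with only the positive term, an encoder could map every observation to the same point and reach zero loss, so the negative sample forces distinct states to stay apart in embedding space without any pixel reconstruction.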

Requirements

  • Python 3.6 or 3.7
  • PyTorch 1.2
  • OpenAI Gym 0.12.0: pip install gym==0.12.0
  • OpenAI atari-py 0.1.4: pip install atari-py==0.1.4
  • scikit-image 0.15.0: pip install scikit-image==0.15.0
  • Matplotlib 3.0.2: pip install matplotlib==3.0.2

Generate datasets

2D Shapes:

python data_gen/env.py --env_id ShapesTrain-v0 --fname data/shapes_train.h5 --num_episodes 1000 --seed 1
python data_gen/env.py --env_id ShapesEval-v0 --fname data/shapes_eval.h5 --num_episodes 10000 --seed 2

3D Cubes:

python data_gen/env.py --env_id CubesTrain-v0 --fname data/cubes_train.h5 --num_episodes 1000 --seed 3
python data_gen/env.py --env_id CubesEval-v0 --fname data/cubes_eval.h5 --num_episodes 10000 --seed 4

Atari Pong:

python data_gen/env.py --env_id PongDeterministic-v4 --fname data/pong_train.h5 --num_episodes 1000 --atari --seed 1
python data_gen/env.py --env_id PongDeterministic-v4 --fname data/pong_eval.h5 --num_episodes 100 --atari --seed 2

Space Invaders:

python data_gen/env.py --env_id SpaceInvadersDeterministic-v4 --fname data/spaceinvaders_train.h5 --num_episodes 1000 --atari --seed 1
python data_gen/env.py --env_id SpaceInvadersDeterministic-v4 --fname data/spaceinvaders_eval.h5 --num_episodes 100 --atari --seed 2

3-Body Gravitational Physics:

python data_gen/physics.py --num-episodes 5000 --fname data/balls_train.h5 --seed 1
python data_gen/physics.py --num-episodes 1000 --fname data/balls_eval.h5 --eval --seed 2
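To generate all datasets in one go, the commands above can be wrapped in a small driver. This is a hypothetical helper, not part of the repository: it assumes it is run from the repository root, and its `--run` switch belongs to the helper itself. By default it only prints the commands, which were copied verbatim from this README (note that physics.py uses `--num-episodes` while env.py uses `--num_episodes`).

```python
"""Hypothetical helper that runs every data-generation command listed in the README."""
import os
import subprocess
import sys

DATAGEN_JOBS = [
    ["data_gen/env.py", "--env_id", "ShapesTrain-v0", "--fname", "data/shapes_train.h5", "--num_episodes", "1000", "--seed", "1"],
    ["data_gen/env.py", "--env_id", "ShapesEval-v0", "--fname", "data/shapes_eval.h5", "--num_episodes", "10000", "--seed", "2"],
    ["data_gen/env.py", "--env_id", "CubesTrain-v0", "--fname", "data/cubes_train.h5", "--num_episodes", "1000", "--seed", "3"],
    ["data_gen/env.py", "--env_id", "CubesEval-v0", "--fname", "data/cubes_eval.h5", "--num_episodes", "10000", "--seed", "4"],
    ["data_gen/env.py", "--env_id", "PongDeterministic-v4", "--fname", "data/pong_train.h5", "--num_episodes", "1000", "--atari", "--seed", "1"],
    ["data_gen/env.py", "--env_id", "PongDeterministic-v4", "--fname", "data/pong_eval.h5", "--num_episodes", "100", "--atari", "--seed", "2"],
    ["data_gen/env.py", "--env_id", "SpaceInvadersDeterministic-v4", "--fname", "data/spaceinvaders_train.h5", "--num_episodes", "1000", "--atari", "--seed", "1"],
    ["data_gen/env.py", "--env_id", "SpaceInvadersDeterministic-v4", "--fname", "data/spaceinvaders_eval.h5", "--num_episodes", "100", "--atari", "--seed", "2"],
    ["data_gen/physics.py", "--num-episodes", "5000", "--fname", "data/balls_train.h5", "--seed", "1"],
    ["data_gen/physics.py", "--num-episodes", "1000", "--fname", "data/balls_eval.h5", "--eval", "--seed", "2"],
]


def generate_all(dry_run=True):
    """Print every command; also execute them in order unless dry_run is set."""
    commands = [[sys.executable] + job for job in DATAGEN_JOBS]
    if not dry_run:
        # Precaution in case the scripts do not create the output folder themselves.
        os.makedirs("data", exist_ok=True)
    for cmd in commands:
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # stop at the first failing command
    return commands


if __name__ == "__main__":
    generate_all(dry_run="--run" not in sys.argv)
```

Using `sys.executable` ensures the scripts run under the same interpreter (and thus the same pinned package versions) as the helper.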

Run model training and evaluation

2D Shapes:

python train.py --dataset data/shapes_train.h5 --encoder small --name shapes
python eval.py --dataset data/shapes_eval.h5 --save-folder checkpoints/shapes --num-steps 1

3D Cubes:

python train.py --dataset data/cubes_train.h5 --encoder large --name cubes
python eval.py --dataset data/cubes_eval.h5 --save-folder checkpoints/cubes --num-steps 1

Atari Pong:

python train.py --dataset data/pong_train.h5 --encoder medium --embedding-dim 4 --action-dim 6 --num-objects 3 --copy-action --epochs 200 --name pong
python eval.py --dataset data/pong_eval.h5 --save-folder checkpoints/pong --num-steps 1

Space Invaders:

python train.py --dataset data/spaceinvaders_train.h5 --encoder medium --embedding-dim 4 --action-dim 6 --num-objects 3 --copy-action --epochs 200 --name spaceinvaders
python eval.py --dataset data/spaceinvaders_eval.h5 --save-folder checkpoints/spaceinvaders --num-steps 1

3-Body Gravitational Physics:

python train.py --dataset data/balls_train.h5 --encoder medium --embedding-dim 4 --num-objects 3 --ignore-action --name balls
python eval.py --dataset data/balls_eval.h5 --save-folder checkpoints/balls --num-steps 1
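Every experiment above follows the same pattern: train.py on `data/<name>_train.h5` with experiment-specific flags, then eval.py on `data/<name>_eval.h5` using `checkpoints/<name>`. The sketch below encodes that pattern in a hypothetical driver, assuming it is run from the repository root after the datasets exist. The flags are copied from this README; the `num_steps` parameter simply fills in `--num-steps`, and the README itself only uses the value 1.

```python
"""Hypothetical helper that trains and then evaluates each C-SWM experiment."""
import subprocess
import sys

# Experiment-specific train.py flags, copied from the README commands.
ATARI_FLAGS = ["--encoder", "medium", "--embedding-dim", "4", "--action-dim", "6",
               "--num-objects", "3", "--copy-action", "--epochs", "200"]
TRAIN_FLAGS = {
    "shapes": ["--encoder", "small"],
    "cubes": ["--encoder", "large"],
    "pong": ATARI_FLAGS,
    "spaceinvaders": ATARI_FLAGS,
    "balls": ["--encoder", "medium", "--embedding-dim", "4", "--num-objects", "3",
              "--ignore-action"],
}


def build_commands(name, num_steps=1):
    """Return [train_cmd, eval_cmd] for one experiment."""
    train = [sys.executable, "train.py", "--dataset", f"data/{name}_train.h5",
             *TRAIN_FLAGS[name], "--name", name]
    # As the README commands imply, the checkpoints end up in checkpoints/<name>.
    evaluate = [sys.executable, "eval.py", "--dataset", f"data/{name}_eval.h5",
                "--save-folder", f"checkpoints/{name}", "--num-steps", str(num_steps)]
    return [train, evaluate]


def run_experiments(names=tuple(TRAIN_FLAGS), dry_run=True):
    """Print the commands for each experiment; execute them unless dry_run is set."""
    for name in names:
        for cmd in build_commands(name):
            print(" ".join(cmd))
            if not dry_run:
                subprocess.run(cmd, check=True)  # eval only runs if training succeeded


if __name__ == "__main__":
    run_experiments(dry_run="--run" not in sys.argv)
```

For example, `run_experiments(["pong"], dry_run=False)` would train and evaluate only the Pong model.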

Cite

If you make use of this code in your own work, please cite our paper:

@article{kipf2019contrastive,
  title={Contrastive Learning of Structured World Models}, 
  author={Kipf, Thomas and van der Pol, Elise and Welling, Max}, 
  journal={arXiv preprint arXiv:1911.12247}, 
  year={2019} 
}