Contrastive Learning of Structured World Models

Related tags

Deep Learningc-swm
Overview

Contrastive Learning of Structured World Models

This repository contains the official PyTorch implementation of:

Contrastive Learning of Structured World Models.
Thomas Kipf, Elise van der Pol, Max Welling.
http://arxiv.org/abs/1911.12247

C-SWM

Abstract: A structured understanding of our world in terms of objects, relations, and hierarchies is an important component of human cognition. Learning such a structured world model from raw sensory data remains a challenge. As a step towards this goal, we introduce Contrastively-trained Structured World Models (C-SWMs). C-SWMs utilize a contrastive approach for representation learning in environments with compositional structure. We structure each state embedding as a set of object representations and their relations, modeled by a graph neural network. This allows objects to be discovered from raw pixel observations without direct supervision as part of the learning process. We evaluate C-SWMs on compositional environments involving multiple interacting objects that can be manipulated independently by an agent, simple Atari games, and a multi-object physics simulation. Our experiments demonstrate that C-SWMs can overcome limitations of models based on pixel reconstruction and outperform typical representatives of this model class in highly structured environments, while learning interpretable object-based representations.

Requirements

  • Python 3.6 or 3.7
  • PyTorch version 1.2
  • OpenAI Gym version: 0.12.0 pip install gym==0.12.0
  • OpenAI Atari_py version: 0.1.4: pip install atari-py==0.1.4
  • Scikit-image version 0.15.0 pip install scikit-image==0.15.0
  • Matplotlib version 3.0.2 pip install matplotlib==3.0.2

Generate datasets

2D Shapes:

python data_gen/env.py --env_id ShapesTrain-v0 --fname data/shapes_train.h5 --num_episodes 1000 --seed 1
python data_gen/env.py --env_id ShapesEval-v0 --fname data/shapes_eval.h5 --num_episodes 10000 --seed 2

3D Cubes:

python data_gen/env.py --env_id CubesTrain-v0 --fname data/cubes_train.h5 --num_episodes 1000 --seed 3
python data_gen/env.py --env_id CubesEval-v0 --fname data/cubes_eval.h5 --num_episodes 10000 --seed 4

Atari Pong:

python data_gen/env.py --env_id PongDeterministic-v4 --fname data/pong_train.h5 --num_episodes 1000 --atari --seed 1
python data_gen/env.py --env_id PongDeterministic-v4 --fname data/pong_eval.h5 --num_episodes 100 --atari --seed 2

Space Invaders:

python data_gen/env.py --env_id SpaceInvadersDeterministic-v4 --fname data/spaceinvaders_train.h5 --num_episodes 1000 --atari --seed 1
python data_gen/env.py --env_id SpaceInvadersDeterministic-v4 --fname data/spaceinvaders_eval.h5 --num_episodes 100 --atari --seed 2

3-Body Gravitational Physics:

python data_gen/physics.py --num-episodes 5000 --fname data/balls_train.h5 --seed 1
python data_gen/physics.py --num-episodes 1000 --fname data/balls_eval.h5 --eval --seed 2

Run model training and evaluation

2D Shapes:

python train.py --dataset data/shapes_train.h5 --encoder small --name shapes
python eval.py --dataset data/shapes_eval.h5 --save-folder checkpoints/shapes --num-steps 1

3D Cubes:

python train.py --dataset data/cubes_train.h5 --encoder large --name cubes
python eval.py --dataset data/cubes_eval.h5 --save-folder checkpoints/cubes --num-steps 1

Atari Pong:

python train.py --dataset data/pong_train.h5 --encoder medium --embedding-dim 4 --action-dim 6 --num-objects 3 --copy-action --epochs 200 --name pong
python eval.py --dataset data/pong_eval.h5 --save-folder checkpoints/pong --num-steps 1

Space Invaders:

python train.py --dataset data/spaceinvaders_train.h5 --encoder medium --embedding-dim 4 --action-dim 6 --num-objects 3 --copy-action --epochs 200 --name spaceinvaders
python eval.py --dataset data/spaceinvaders_eval.h5 --save-folder checkpoints/spaceinvaders --num-steps 1

3-Body Gravitational Physics:

python train.py --dataset data/balls_train.h5 --encoder medium --embedding-dim 4 --num-objects 3 --ignore-action --name balls
python eval.py --dataset data/balls_eval.h5 --save-folder checkpoints/balls --num-steps 1

Cite

If you make use of this code in your own work, please cite our paper:

@article{kipf2019contrastive,
  title={Contrastive Learning of Structured World Models}, 
  author={Kipf, Thomas and van der Pol, Elise and Welling, Max}, 
  journal={arXiv preprint arXiv:1911.12247}, 
  year={2019} 
}
Owner
Thomas Kipf
Thomas Kipf
SBINN: Systems-biology informed neural network

SBINN: Systems-biology informed neural network The source code for the paper M. Daneker, Z. Zhang, G. E. Karniadakis, & L. Lu. Systems biology: Identi

Lu Group 15 Nov 19, 2022
Internship Assessment Task for BaggageAI.

BaggageAI Internship Task Problem Statement: You are given two sets of images:- background and threat objects. Background images are the background x-

Arya Shah 10 Nov 14, 2022
A customisable game where you have to quickly click on black tiles in order of appearance while avoiding clicking on white squares.

W.I.P-Aim-Memory-Game A customisable game where you have to quickly click on black tiles in order of appearance while avoiding clicking on white squar

dE_soot 1 Dec 08, 2021
Points2Surf: Learning Implicit Surfaces from Point Clouds (ECCV 2020 Spotlight)

Points2Surf: Learning Implicit Surfaces from Point Clouds (ECCV 2020 Spotlight)

Philipp Erler 329 Jan 06, 2023
Dyalog-apl-docset - Dyalog APL Dash Docset Generator

Dyalog APL Dash Docset Generator o alasa e kili sona kepeken tenpo lili a A Dash

Maciej Goszczycki 1 Jan 10, 2022
Single-Stage Instance Shadow Detection with Bidirectional Relation Learning (CVPR 2021 Oral)

Single-Stage Instance Shadow Detection with Bidirectional Relation Learning (CVPR 2021 Oral) Tianyu Wang*, Xiaowei Hu*, Chi-Wing Fu, and Pheng-Ann Hen

Steve Wong 51 Oct 20, 2022
Head and Neck Tumour Segmentation and Prediction of Patient Survival Project

Head-and-Neck-Tumour-Segmentation-and-Prediction-of-Patient-Survival Welcome to the Head and Neck Tumour Segmentation and Prediction of Patient Surviv

5 Oct 20, 2022
Official tensorflow implementation for CVPR2020 paper “Learning to Cartoonize Using White-box Cartoon Representations”

Tensorflow implementation for CVPR2020 paper “Learning to Cartoonize Using White-box Cartoon Representations”.

3.7k Dec 31, 2022
Code for paper "Multi-level Disentanglement Graph Neural Network"

Multi-level Disentanglement Graph Neural Network (MD-GNN) This is a PyTorch implementation of the MD-GNN, and the code includes the following modules:

Lirong Wu 6 Dec 29, 2022
Official repository of OFA. Paper: Unifying Architectures, Tasks, and Modalities Through a Simple Sequence-to-Sequence Learning Framework

Paper | Blog OFA is a unified multimodal pretrained model that unifies modalities (i.e., cross-modality, vision, language) and tasks (e.g., image gene

OFA Sys 1.4k Jan 08, 2023
Yolo Traffic Light Detection With Python

Yolo-Traffic-Light-Detection This project is based on detecting the Traffic light. Pretained data is used. This application entertained both real time

Ananta Raj Pant 2 Aug 08, 2022
Pytorch Implementation of Value Retrieval with Arbitrary Queries for Form-like Documents.

Value Retrieval with Arbitrary Queries for Form-like Documents Introduction Pytorch Implementation of Value Retrieval with Arbitrary Queries for Form-

Salesforce 13 Sep 15, 2022
Mixed Transformer UNet for Medical Image Segmentation

MT-UNet Update 2021/11/19 Thank you for your interest in our work. We have uploaded the code of our MTUNet to help peers conduct further research on i

dotman 92 Dec 25, 2022
IDRLnet, a Python toolbox for modeling and solving problems through Physics-Informed Neural Network (PINN) systematically.

IDRLnet IDRLnet is a machine learning library on top of PyTorch. Use IDRLnet if you need a machine learning library that solves both forward and inver

IDRL 105 Dec 17, 2022
TransPrompt - Towards an Automatic Transferable Prompting Framework for Few-shot Text Classification

TransPrompt This code is implement for our EMNLP 2021's paper 《TransPrompt:Towards an Automatic Transferable Prompting Framework for Few-shot Text Cla

WangJianing 23 Dec 21, 2022
Official repository for "On Generating Transferable Targeted Perturbations" (ICCV 2021)

On Generating Transferable Targeted Perturbations (ICCV'21) Muzammal Naseer, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Fatih Porikli Paper:

Muzammal Naseer 46 Nov 17, 2022
Efficient 3D Backbone Network for Temporal Modeling

VoV3D is an efficient and effective 3D backbone network for temporal modeling implemented on top of PySlowFast. Diverse Temporal Aggregation and

102 Dec 06, 2022
Facial recognition project

Facial recognition project documentation Project introduction This project is developed by linuxu. It is a face model recognition project developed ba

Jefferson 2 Dec 04, 2022
An image processing project uses Viola-jones technique to detect faces and then use SIFT algorithm for recognition.

Attendance_System An image processing project uses Viola-jones technique to detect faces and then use LPB algorithm for recognition. Face Detection Us

8 Jan 11, 2022
Multi-agent reinforcement learning algorithm and environment

Multi-agent reinforcement learning algorithm and environment [en/cn] Pytorch implements multi-agent reinforcement learning algorithms including IQL, Q

万鲲鹏 7 Sep 20, 2022