Train an RL agent to execute natural language instructions in a 3D Environment (PyTorch)

Overview

Gated-Attention Architectures for Task-Oriented Language Grounding

This is a PyTorch implementation of the AAAI-18 paper:

Gated-Attention Architectures for Task-Oriented Language Grounding
Devendra Singh Chaplot, Kanthashree Mysore Sathyendra, Rama Kumar Pasumarthi, Dheeraj Rajagopal, Ruslan Salakhutdinov
Carnegie Mellon University

Project Website: https://sites.google.com/view/gated-attention

example

This repository contains:

  • Code for training an A3C-LSTM agent using Gated-Attention
  • Code for Doom-based language grounding environment

Dependencies

(We recommend using Anaconda)

Usage

Using the Environment

For running a random agent:

python env_test.py

To play in the environment:

python env_test.py --interactive 1

To change the difficulty of the environment (easy/medium/hard):

python env_test.py -d easy

Training Gated-Attention A3C-LSTM agent

For training a A3C-LSTM agent with 32 threads:

python a3c_main.py --num-processes 32 --evaluate 0

The code will save the best model at ./saved/model_best.

To the test the pre-trained model for Multitask Generalization:

python a3c_main.py --evaluate 1 --load saved/pretrained_model

To the test the pre-trained model for Zero-shot Task Generalization:

python a3c_main.py --evaluate 2 --load saved/pretrained_model

To the visualize the model while testing add '--visualize 1':

python a3c_main.py --evaluate 2 --load saved/pretrained_model --visualize 1

To test the trained model, use --load saved/model_best in the above commands.

All arguments for a3c_main.py:

  -h, --help            show this help message and exit
  -l MAX_EPISODE_LENGTH, --max-episode-length MAX_EPISODE_LENGTH
                        maximum length of an episode (default: 30)
  -d DIFFICULTY, --difficulty DIFFICULTY
                        Difficulty of the environment, "easy", "medium" or
                        "hard" (default: hard)
  --living-reward LIVING_REWARD
                        Default reward at each time step (default: 0, change
                        to -0.005 to encourage shorter paths)
  --frame-width FRAME_WIDTH
                        Frame width (default: 300)
  --frame-height FRAME_HEIGHT
                        Frame height (default: 168)
  -v VISUALIZE, --visualize VISUALIZE
                        Visualize the envrionment (default: 0, use 0 for
                        faster training)
  --sleep SLEEP         Sleep between frames for better visualization
                        (default: 0)
  --scenario-path SCENARIO_PATH
                        Doom scenario file to load (default: maps/room.wad)
  --interactive INTERACTIVE
                        Interactive mode enables human to play (default: 0)
  --all-instr-file ALL_INSTR_FILE
                        All instructions file (default:
                        data/instructions_all.json)
  --train-instr-file TRAIN_INSTR_FILE
                        Train instructions file (default:
                        data/instructions_train.json)
  --test-instr-file TEST_INSTR_FILE
                        Test instructions file (default:
                        data/instructions_test.json)
  --object-size-file OBJECT_SIZE_FILE
                        Object size file (default: data/object_sizes.txt)
  --lr LR               learning rate (default: 0.001)
  --gamma G             discount factor for rewards (default: 0.99)
  --tau T               parameter for GAE (default: 1.00)
  --seed S              random seed (default: 1)
  -n N, --num-processes N
                        how many training processes to use (default: 4)
  --num-steps NS        number of forward steps in A3C (default: 20)
  --load LOAD           model path to load, 0 to not reload (default: 0)
  -e EVALUATE, --evaluate EVALUATE
                        0:Train, 1:Evaluate MultiTask Generalization
                        2:Evaluate Zero-shot Generalization (default: 0)
  --dump-location DUMP_LOCATION
                        path to dump models and log (default: ./saved/)

Demostration videos:

Multitask Generalization video: https://www.youtube.com/watch?v=YJG8fwkv7gA

Zero-shot Task Generalization video: https://www.youtube.com/watch?v=JziCKsLrudE

Different stages of training: https://www.youtube.com/watch?v=o_G6was03N0

Cite as

Chaplot, D.S., Sathyendra, K.M., Pasumarthi, R.K., Rajagopal, D. and Salakhutdinov, R., 2017. Gated-Attention Architectures for Task-Oriented Language Grounding. arXiv preprint arXiv:1706.07230. (PDF)

Bibtex:

@article{chaplot2017gated,
  title={Gated-Attention Architectures for Task-Oriented Language Grounding},
  author={Chaplot, Devendra Singh and Sathyendra, Kanthashree Mysore and Pasumarthi, Rama Kumar and Rajagopal, Dheeraj and Salakhutdinov, Ruslan},
  journal={arXiv preprint arXiv:1706.07230},
  year={2017}
}

Acknowledgements

This repository uses ViZDoom API (https://github.com/mwydmuch/ViZDoom) and parts of the code from the API. The implementation of A3C is borrowed from https://github.com/ikostrikov/pytorch-a3c. The poisson-disc code is borrowed from https://github.com/IHautaI/poisson-disc.

Owner
Devendra Chaplot
Ph.D. student in Machine Learning Dept., School of Computer Science, CMU.
Devendra Chaplot
Robust Lane Detection via Expanded Self Attention (WACV 2022)

Robust Lane Detection via Expanded Self Attention (WACV 2022) Minhyeok Lee, Junhyeop Lee, Dogyoon Lee, Woojin Kim, Sangwon Hwang, Sangyoun Lee Overvie

Min Hyeok Lee 18 Nov 12, 2022
Real-Time SLAM for Monocular, Stereo and RGB-D Cameras, with Loop Detection and Relocalization Capabilities

ORB-SLAM2 Authors: Raul Mur-Artal, Juan D. Tardos, J. M. M. Montiel and Dorian Galvez-Lopez (DBoW2) 13 Jan 2017: OpenCV 3 and Eigen 3.3 are now suppor

Raul Mur-Artal 7.8k Dec 30, 2022
Repositório criado para abrigar os notebooks com a listas de exercícios propostos pelo professor Gustavo Guanabara do canal Curso em Vídeo do YouTube durante o Curso de Python 3

Curso em Vídeo - Exercícios de Python 3 Sobre o repositório Este repositório contém os notebooks com a listas de exercícios propostos pelo professor G

João Pedro Pereira 9 Oct 15, 2022
A tf.keras implementation of Facebook AI's MadGrad optimization algorithm

MADGRAD Optimization Algorithm For Tensorflow This package implements the MadGrad Algorithm proposed in Adaptivity without Compromise: A Momentumized,

20 Aug 18, 2022
Data, notebooks, and articles associated with the RSNA AI Deep Learning Lab at RSNA 2021

RSNA AI Deep Learning Lab 2021 Intro Welcome Deep Learners! This document provides all the information you need to participate in the RSNA AI Deep Lea

RSNA 65 Dec 16, 2022
Kaggle G2Net Gravitational Wave Detection : 2nd place solution

Kaggle G2Net Gravitational Wave Detection : 2nd place solution

Hiroshechka Y 33 Dec 26, 2022
[NeurIPS 2021] Code for Unsupervised Learning of Compositional Energy Concepts

Unsupervised Learning of Compositional Energy Concepts This is the pytorch code for the paper Unsupervised Learning of Compositional Energy Concepts.

45 Nov 30, 2022
This repository contains the code and models for the following paper.

DC-ShadowNet Introduction This is an implementation of the following paper DC-ShadowNet: Single-Image Hard and Soft Shadow Removal Using Unsupervised

AuAgCu 65 Dec 27, 2022
[CVPR 2021] MiVOS - Scribble to Mask module

MiVOS (CVPR 2021) - Scribble To Mask Ho Kei Cheng, Yu-Wing Tai, Chi-Keung Tang [arXiv] [Paper PDF] [Project Page] A simplistic network that turns scri

Rex Cheng 65 Dec 22, 2022
A PyTorch Toolbox for Face Recognition

FaceX-Zoo FaceX-Zoo is a PyTorch toolbox for face recognition. It provides a training module with various supervisory heads and backbones towards stat

JDAI-CV 1.6k Jan 06, 2023
Instance Semantic Segmentation List

Instance Semantic Segmentation List This repository contains lists of state-or-art instance semantic segmentation works. Papers and resources are list

bighead 87 Mar 06, 2022
Autotype on websites that have copy-paste disabled like Moodle, HackerEarth contest etc.

Autotype A quick and small python script that helps you autotype on websites that have copy paste disabled like Moodle, HackerEarth contests etc as it

Tushar 32 Nov 03, 2022
DEEPAGÉ: Answering Questions in Portuguese about the Brazilian Environment

DEEPAGÉ: Answering Questions in Portuguese about the Brazilian Environment This repository is related to the paper DEEPAGÉ: Answering Questions in Por

0 Dec 10, 2021
Official implementation of NLOS-OT: Passive Non-Line-of-Sight Imaging Using Optimal Transport (IEEE TIP, accepted)

NLOS-OT Official implementation of NLOS-OT: Passive Non-Line-of-Sight Imaging Using Optimal Transport (IEEE TIP, accepted) Description In this reposit

Ruixu Geng(耿瑞旭) 16 Dec 16, 2022
Running Google MoveNet Multipose Tracking models on OpenVINO.

MoveNet MultiPose Tracking on OpenVINO

60 Nov 17, 2022
A minimalist tool to display a network graph.

A tool to get a minimalist view of any architecture This tool has only be tested with the models included in this repo. Therefore, I can't guarantee t

Thibault Castells 1 Feb 11, 2022
Beyond Image to Depth: Improving Depth Prediction using Echoes (CVPR 2021)

Beyond Image to Depth: Improving Depth Prediction using Echoes (CVPR 2021) Kranti Kumar Parida, Siddharth Srivastava, Gaurav Sharma. We address the pr

Kranti Kumar Parida 33 Jun 27, 2022
A Graph Neural Network Tool for Recovering Dense Sub-graphs in Random Dense Graphs.

PYGON A Graph Neural Network Tool for Recovering Dense Sub-graphs in Random Dense Graphs. Installation This code requires to install and run the graph

Yoram Louzoun's Lab 0 Jun 25, 2021
[ICML 2020] Prediction-Guided Multi-Objective Reinforcement Learning for Continuous Robot Control

PG-MORL This repository contains the implementation for the paper Prediction-Guided Multi-Objective Reinforcement Learning for Continuous Robot Contro

MIT Graphics Group 65 Jan 07, 2023
An implementation of chunked, compressed, N-dimensional arrays for Python.

Zarr Latest Release Package Status License Build Status Coverage Downloads Gitter Citation What is it? Zarr is a Python package providing an implement

Zarr Developers 1.1k Dec 30, 2022