Open-source codebase for EfficientZero, from "Mastering Atari Games with Limited Data" at NeurIPS 2021.

Overview

EfficientZero (NeurIPS 2021)

Open-source codebase for EfficientZero, from "Mastering Atari Games with Limited Data" at NeurIPS 2021.

Environments

EfficientZero requires python3 (>=3.6) and pytorch (>=1.8.0) with the development headers.

We recommend to use torch amp (--amp_type torch_amp) to accelerate training.

Prerequisites

Before starting training, you need to build the c++/cython style external packages.

cd core/ctree
bash make.sh

The distributed framework of this codebase is built on ray.

Installation

As for other packages required for this codebase, please run pip install -r requirements.txt.

Usage

Quick start

  • Train: python main.py --env BreakoutNoFrameskip-v4 --case atari --opr train --amp_type torch_amp --num_gpus 1 --num_cpus 10 --cpu_actor 1 --gpu_actor 1 --force
  • Test: python main.py --env BreakoutNoFrameskip-v4 --case atari --opr test --amp_type torch_amp --num_gpus 1 --load_model --model_path model.p \

Bash file

We provide train.sh and test.sh for training and evaluation.

  • Train:
    • With 4 GPUs (3090): bash train.sh
  • Test: bash test.sh
Required Arguments Description
--env Name of the environment
--case {atari} It's used for switching between different domains(default: atari)
--opr {train,test} select the operation to be performed
--amp_type {torch_amp,none} use torch amp for acceleration
Other Arguments Description
--force will rewrite the result directory
--num_gpus 4 how many GPUs are available
--num_cpus 96 how many CPUs are available
--cpu_actor 14 how many cpu workers
--gpu_actor 20 how many gpu workers
--seed 0 the seed
--use_priority use priority in replay buffer sampling
--use_max_priority use the max priority for the newly collectted data
--amp_type 'torch_amp' use torch amp for acceleration
--info 'EZ-V0' some tags for you experiments
--p_mcts_num 8 set the parallel number of envs in self-play
--revisit_policy_search_rate 0.99 set the rate of reanalyzing policies
--use_root_value use root values in value targets (require more GPU actors)
--render render in evaluation
--save_video save videos for evaluation

Architecture Designs

The architecture of the training pipeline is shown as follows:

Some suggestions

  • To use a smaller model, you can choose smaller dim of the projection layers (Eg: 256/64) and the LSTM hidden layer (Eg: 64) in the config.
  • For GPUs with 10G memory instead of 20G memory, you can allocate 0.25 gpu for each GPU maker (@ray.remote(num_gpus=0.25)) in core/reanalyze_worker.py.

New environment registration

If you wan to apply EfficientZero to a new environment like mujoco. Here are the steps for registration:

  1. Follow the directory config/atari and create dir for the env at config/mujoco.
  2. Implement your MujocoConfig(BaseConfig) class and implement the models as well as your environment wrapper.
  3. Register the case at main.py.

Results

Evaluation with 32 seeds for 3 different runs (different seeds).

Citation

If you find this repo useful, please cite our paper:

@inproceedings{ye2021mastering,
  title={Mastering Atari Games with Limited Data},
  author={Weirui Ye, and Shaohuai Liu, and Thanard Kurutach, and Pieter Abbeel, and Yang Gao},
  booktitle={NeurIPS},
  year={2021}
}

Contact

If you have any question or want to use the code, please contact [email protected] .

Acknowledgement

We appreciate the following github repos a lot for their valuable code base implementations:

https://github.com/koulanurag/muzero-pytorch

https://github.com/werner-duvaud/muzero-general

https://github.com/pytorch/ELF

Owner
Weirui Ye
Weirui Ye
This is an official implementation for "SimMIM: A Simple Framework for Masked Image Modeling".

Project This repo has been populated by an initial template to help get you started. Please make sure to update the content to build a great experienc

Microsoft 674 Dec 26, 2022
Training DiffWave using variational method from Variational Diffusion Models.

Variational DiffWave Training DiffWave using variational method from Variational Diffusion Models. Quick Start python train_distributed.py discrete_10

Chin-Yun Yu 26 Dec 13, 2022
DeRF: Decomposed Radiance Fields

DeRF: Decomposed Radiance Fields Daniel Rebain, Wei Jiang, Soroosh Yazdani, Ke Li, Kwang Moo Yi, Andrea Tagliasacchi Links Paper Project Page Abstract

UBC Computer Vision Group 24 Dec 02, 2022
Detection of drones using their thermal signatures from thermal camera through YOLO-V3 based CNN with modifications to encapsulate drone motion

Drone Detection using Thermal Signature This repository highlights the work for night-time drone detection using a using an Optris PI Lightweight ther

Chong Yu Quan 6 Dec 31, 2022
Breaking the Dilemma of Medical Image-to-image Translation

Breaking the Dilemma of Medical Image-to-image Translation Supervised Pix2Pix and unsupervised Cycle-consistency are two modes that dominate the field

Kid Liet 86 Dec 21, 2022
Material for my PyConDE & PyData Berlin 2022 Talk "5 Steps to Speed Up Your Data-Analysis on a Single Core"

5 Steps to Speed Up Your Data-Analysis on a Single Core Material for my talk at the PyConDE & PyData Berlin 2022 Description Your data analysis pipeli

Jonathan Striebel 9 Dec 12, 2022
Source code for CAST - Crisis Domain Adaptation Using Sequence-to-sequence Transformers (Accepted to ISCRAM 2021, CorePaper).

Source code for CAST: Crisis Domain Adaptation UsingSequence-to-sequenceTransformers (Paper, BibTeX, Accepted to ISCRAM 2021, CorePaper) Quick start D

Congcong Wang 0 Jul 14, 2021
A tensorflow implementation of Fully Convolutional Networks For Semantic Segmentation

##A tensorflow implementation of Fully Convolutional Networks For Semantic Segmentation. #USAGE To run the trained classifier on some images: python w

Alex Seewald 13 Nov 17, 2022
[EMNLP 2021] MuVER: Improving First-Stage Entity Retrieval with Multi-View Entity Representations

MuVER This repo contains the code and pre-trained model for our EMNLP 2021 paper: MuVER: Improving First-Stage Entity Retrieval with Multi-View Entity

24 May 30, 2022
A new benchmark for Icon Question Answering (IconQA) and a large-scale icon dataset Icon645.

IconQA About IconQA is a new diverse abstract visual question answering dataset that highlights the importance of abstract diagram understanding and c

Pan Lu 24 Dec 30, 2022
Construct a neural network frame by Numpy

本项目的CSDN博客链接:https://blog.csdn.net/weixin_41578567/article/details/111482022 1. 概览 本项目主要用于神经网络的学习,通过基于numpy的实现,了解神经网络底层前向传播、反向传播以及各类优化器的原理。 该项目目前已实现的功

24 Jan 22, 2022
Codes for realizing theories learned from Data Mining, Machine Learning, Deep Learning without using the present Python packages.

Codes-for-Algorithms Codes for realizing theories learned from Data Mining, Machine Learning, Deep Learning without using the present Python packages.

Tracy (Shengmin) Tao 1 Apr 12, 2022
Official PyTorch implementation of "Improving Face Recognition with Large AgeGaps by Learning to Distinguish Children" (BMVC 2021)

Inter-Prototype (BMVC 2021): Official Project Webpage This repository provides the official PyTorch implementation of the following paper: Improving F

Jungsoo Lee 16 Jun 30, 2022
Code base of object detection

rmdet code base of object detection. 环境安装: 1. 安装conda python环境 - `conda create -n xxx python=3.7/3.8` - `conda activate xxx` 2. 运行脚本,自动安装pytorch1

3 Mar 08, 2022
The official implementation of the CVPR2021 paper: Decoupled Dynamic Filter Networks

Decoupled Dynamic Filter Networks This repo is the official implementation of CVPR2021 paper: "Decoupled Dynamic Filter Networks". Introduction DDF is

F.S.Fire 180 Dec 30, 2022
Unsupervised Feature Loss (UFLoss) for High Fidelity Deep learning (DL)-based reconstruction

Unsupervised Feature Loss (UFLoss) for High Fidelity Deep learning (DL)-based reconstruction Official github repository for the paper High Fidelity De

28 Dec 16, 2022
StyleGAN-NADA: CLIP-Guided Domain Adaptation of Image Generators

StyleGAN-NADA: CLIP-Guided Domain Adaptation of Image Generators [Project Website] [Replicate.ai Project] StyleGAN-NADA: CLIP-Guided Domain Adaptation

992 Dec 30, 2022
Python implementation of ADD: Frequency Attention and Multi-View based Knowledge Distillation to Detect Low-Quality Compressed Deepfake Images, AAAI2022.

ADD: Frequency Attention and Multi-View based Knowledge Distillation to Detect Low-Quality Compressed Deepfake Images Binh M. Le & Simon S. Woo, "ADD:

2 Oct 24, 2022
Attention-based CNN-LSTM and XGBoost hybrid model for stock prediction

Attention-based CNN-LSTM and XGBoost hybrid model for stock prediction Requirements The code has been tested running under Python 3.7.4, with the foll

zshicode 84 Jan 01, 2023
Generative Modelling of BRDF Textures from Flash Images [SIGGRAPH Asia, 2021]

Neural Material Official code repository for the paper: Generative Modelling of BRDF Textures from Flash Images [SIGGRAPH Asia, 2021] Henzler, Deschai

Philipp Henzler 80 Dec 20, 2022