Pretraining Representations For Data-Efficient Reinforcement Learning

Related tags

Deep LearningSGI
Overview

Pretraining Representations For Data-Efficient Reinforcement Learning

Max Schwarzer, Nitarshan Rajkumar, Michael Noukhovitch, Ankesh Anand, Laurent Charlin, Devon Hjelm, Philip Bachman & Aaron Courville

This repo provides code for implementing SGI.

Install

To install the requirements, follow these steps:

# PyTorch
export LANG=C.UTF-8
# Install requirements
pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
pip install -r requirements.txt

# Finally, install the project
pip install --user -e .

Usage:

The default branch for the latest and stable changes is release.

  • To run SGI:
  1. Download the DQN replay dataset from https://research.google/tools/datasets/dqn-replay/
    • Or substitute your own pre-training data! The codebase expects a series of .gz files, one each for observations, actions and terminals.
  2. To pretrain with SGI:
python -m scripts.run public=True model_folder=./ offline.runner.save_every=2500 \
    env.game=pong seed=1 offline_model_save={your model name} \
    offline.runner.epochs=10 offline.runner.dataloader.games=[Pong] \
    offline.runner.no_eval=1 \
    +offline.algo.goal_weight=1 \
    +offline.algo.inverse_model_weight=1 \
    +offline.algo.spr_weight=1 \
    +offline.algo.target_update_tau=0.01 \
    +offline.agent.model_kwargs.momentum_tau=0.01 \
    do_online=False \
    algo.batch_size=256 \
    +offline.agent.model_kwargs.noisy_nets_std=0 \
    offline.runner.dataloader.dataset_on_disk=True \
    offline.runner.dataloader.samples=1000000 \
    offline.runner.dataloader.checkpoints='{your checkpoints}' \
    offline.runner.dataloader.num_workers=2 \
    offline.runner.dataloader.data_path={your data dir} \
    offline.runner.dataloader.tmp_data_path=./ 
  1. To fine-tune with SGI:
python -m scripts.run public=True env.game=pong seed=1 num_logs=10  \
    model_load={your_model_name} model_folder=./ \
    algo.encoder_lr=0.000001 algo.q_l1_lr=0.00003 algo.clip_grad_norm=-1 algo.clip_model_grad_norm=-1

When reporting scores, we average across 10 fine-tuning seeds.

./scripts/experiments contains a number of example configurations, including for SGI-M, SGI-M/L and SGI-W, for both pre-training and fine-tuning. Each of these scripts can be launched by providing a game and seed, e.g., ./scripts/experiments/sgim_pretrain.sh pong 1. These scripts are provided primarily to illustrate the hyperparameters used for different experiments; you will likely need to modify the arguments in these scripts to point to your data and model directories.

Data for SGI-R and SGI-E is not included due to its size, but can be re-generated locally. Contact us for details.

What does each file do?

.
β”œβ”€β”€ scripts
β”‚   β”œβ”€β”€ run.py                # The main runner script to launch jobs.
β”‚   β”œβ”€β”€ config.yaml           # The hydra configuration file, listing hyperparameters and options.
|   └── experiments           # Configurations for various experiments done by SGI.
|   
β”œβ”€β”€ src                     
β”‚   β”œβ”€β”€ agent.py              # Implements the Agent API for action selection 
β”‚   β”œβ”€β”€ algos.py              # Distributional RL loss and optimization
β”‚   β”œβ”€β”€ models.py             # Forward passes, network initialization.
β”‚   β”œβ”€β”€ networks.py           # Network architecture and forward passes.
β”‚   β”œβ”€β”€ offline_dataset.py    # Dataloader for offline data.
β”‚   β”œβ”€β”€ gcrl.py               # Utils for SGI's goal-conditioned RL objective.
β”‚   β”œβ”€β”€ rlpyt_atari_env.py    # Slightly modified Atari env from rlpyt
β”‚   β”œβ”€β”€ rlpyt_utils.py        # Utility methods that we use to extend rlpyt's functionality
β”‚   └── utils.py              # Command line arguments and helper functions 
β”‚
└── requirements.txt          # Dependencies
Owner
Mila
Quebec Artificial Intelligence Institute
Mila
AdaMML: Adaptive Multi-Modal Learning for Efficient Video Recognition

AdaMML: Adaptive Multi-Modal Learning for Efficient Video Recognition [ArXiv] [Project Page] This repository is the official implementation of AdaMML:

International Business Machines 43 Dec 26, 2022
Transfer-Learn is an open-source and well-documented library for Transfer Learning.

Transfer-Learn is an open-source and well-documented library for Transfer Learning. It is based on pure PyTorch with high performance and friendly API. Our code is pythonic, and the design is consist

THUML @ Tsinghua University 2.2k Jan 03, 2023
Code for How To Create A Fully Automated AI Based Trading System WithΒ Python

AI Based Trading System This code works as a boilerplate for an AI based trading system with yfinance as data source and RobinHood or Alpaca as broker

RubΓ©n 196 Jan 05, 2023
Localized representation learning from Vision and Text (LoVT)

Localized Vision-Text Pre-Training Contrastive learning has proven effective for pre- training image models on unlabeled data and achieved great resul

Philip MΓΌller 10 Dec 07, 2022
Fully convolutional deep neural network to remove transparent overlays from images

Fully convolutional deep neural network to remove transparent overlays from images

Marc Belmont 1.1k Jan 06, 2023
A Pytorch implement of paper "Anomaly detection in dynamic graphs via transformer" (TADDY).

TADDY: Anomaly detection in dynamic graphs via transformer This repo covers an reference implementation for the paper "Anomaly detection in dynamic gr

Yue Tan 21 Nov 24, 2022
A python program to hack instagram

hackinsta a program to hack instagram Yokoback_(instahack) is the file to open, you need libraries write on import. You run that file in the same fold

2 Jan 22, 2022
Robust & Reliable Route Recommendation on Road Networks

NeuroMLR: Robust & Reliable Route Recommendation on Road Networks This repository is the official implementation of NeuroMLR: Robust & Reliable Route

4 Dec 20, 2022
Exploration & Research into cross-domain MEV. Initial focus on ETH/POLYGON.

xMEV, an apt exploration This is a small exploration on the xMEV opportunities between Polygon and Ethereum. It's a data analysis exercise on a few pa

odyslam.eth 7 Oct 18, 2022
Syntax-Aware Action Targeting for Video Captioning

Syntax-Aware Action Targeting for Video Captioning Code for SAAT from "Syntax-Aware Action Targeting for Video Captioning" (Accepted to CVPR 2020). Th

59 Oct 13, 2022
A Python library for working with arbitrary-dimension hypercomplex numbers following the Cayley-Dickson construction of algebras.

Hypercomplex A Python library for working with quaternions, octonions, sedenions, and beyond following the Cayley-Dickson construction of hypercomplex

7 Nov 04, 2022
ESP32 python application to read data from a Tiltβ„’ Hydrometer for homebrewing

TitlESP32 ESP32 MicroPython application to read and log data from a Tiltβ„’ Hydrometer. Requirements A board with an ESP32 chip USB cable - USB A / micr

IoBeer 5 Dec 01, 2022
The CLRS Algorithmic Reasoning Benchmark

Learning representations of algorithms is an emerging area of machine learning, seeking to bridge concepts from neural networks with classical algorithms.

DeepMind 251 Jan 05, 2023
Boosting Adversarial Attacks with Enhanced Momentum (BMVC 2021)

EMI-FGSM This repository contains code to reproduce results from the paper: Boosting Adversarial Attacks with Enhanced Momentum (BMVC 2021) Xiaosen Wa

John Hopcroft Lab at HUST 10 Sep 26, 2022
Long Expressive Memory (LEM)

Long Expressive Memory for Sequence Modeling This repository contains the implementation to reproduce the numerical experiments of the paper Long Expr

Konstantin Rusch 47 Dec 17, 2022
This is the official repository of Music Playlist Title Generation: A Machine-Translation Approach.

PlyTitle_Generation This is the official repository of Music Playlist Title Generation: A Machine-Translation Approach. The paper has been accepted by

SeungHeonDoh 6 Jan 03, 2022
Recommendation algorithms for large graphs

Fast recommendation algorithms for large graphs based on link analysis. License: Apache Software License Author: Emmanouil (Manios) Krasanakis Depende

Multimedia Knowledge and Social Analytics Lab 27 Jan 07, 2023
Collection of common code that's shared among different research projects in FAIR computer vision team.

fvcore fvcore is a light-weight core library that provides the most common and essential functionality shared in various computer vision frameworks de

Meta Research 1.5k Jan 07, 2023
PyTorch implementation of Deformable Convolution

Deformable Convolutional Networks in PyTorch This repo is an implementation of Deformable Convolution. Ported from author's MXNet implementation. Buil

411 Dec 16, 2022
Create animations for the optimization trajectory of neural nets

Animating the Optimization Trajectory of Neural Nets loss-landscape-anim lets you create animated optimization path in a 2D slice of the loss landscap

Logan Yang 81 Dec 25, 2022