Additional code for Stable-baselines3 to load and upload models from the Hub.

Overview

Hugging Face x Stable-baselines3

A library to load and upload Stable-baselines3 models from the Hub.

Installation

With pip:

pip install huggingface_sb3

Examples

[Todo: add colab tutorial]

Case 1: I want to download a model from the Hub

import gym

from huggingface_sb3 import load_from_hub
from stable_baselines3 import PPO

env = gym.make("CartPole-v1")

# Retrieve the model from the Hub
## repo_id = id of the model repository from the Hugging Face Hub (repo_id = {organization}/{repo_name})
## filename = name of the model zip file from the repository
checkpoint = load_from_hub(repo_id="ThomasSimonini/ppo-CartPole-v1", filename="CartPole-v1")
# Load the trained agent from the downloaded checkpoint
model = PPO.load(checkpoint)

obs = env.reset()
for i in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)
    env.render()
    if done:
        obs = env.reset()

env.close()

Case 2: I trained an agent and want to upload it to the Hub

First you need to be logged in to Hugging Face:

  • If you're using Colab/Jupyter Notebooks:
from huggingface_hub import notebook_login
notebook_login()
  • Otherwise, from a terminal:
huggingface-cli login

Then:

import gym
from huggingface_sb3 import push_to_hub
from stable_baselines3 import PPO

# Create the environment
env = gym.make('CartPole-v1')

# Define a PPO MlpPolicy architecture
model = PPO('MlpPolicy', env, verbose=1)

# Train it for 10000 timesteps
model.learn(total_timesteps=10000)

# Save the model (this writes CartPole-v1.zip)
model.save("CartPole-v1")

# Push the saved model to the Hub repository
# If the repository does not exist, it will be created
## filename: must match the name passed to model.save() ("CartPole-v1")
push_to_hub(repo_name="CartPole-v1",
            organization="ThomasSimonini",
            filename="CartPole-v1",
            commit_message="Added CartPole-v1 trained model")
Comments
  • Environment name normalization and explicit naming schemes

    There was an issue with environment names that contain a slash (see https://github.com/DLR-RM/rl-baselines3-zoo/pull/257). Also, the naming scheme for models and repository IDs is based only on convention.

    This PR normalizes environment names (replacing slashes with dashes) and encodes the naming scheme for models and repository IDs in small helper classes. The idea is that downstream libraries (such as the RL Baselines3 Zoo) can use these helper classes to comply with the naming scheme. If we ever need to change the naming scheme, or other cases come up in which the environment name needs to be normalized, we can implement the change here and the downstream libraries immediately profit from it.

    I also added a simple smoke test for pulling a model from the hub.
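The naming convention described in this PR can be illustrated with a small sketch. The helper names below (normalize_env_name, model_name, model_repo_id) are hypothetical and are not the helper classes added by the PR; the "{algo}-{env}" pattern is only inferred from repository IDs such as ThomasSimonini/ppo-CartPole-v1, so the library's actual scheme may differ.

```python
# Hypothetical helpers illustrating the convention; not the library's actual classes.


def normalize_env_name(env_id: str) -> str:
    """Replace slashes (e.g. in namespaced env IDs) with dashes so the name is repo/path safe."""
    return env_id.replace("/", "-")


def model_name(algo: str, env_id: str) -> str:
    """Model name of the assumed form '{algo}-{normalized env}', e.g. 'ppo-CartPole-v1'."""
    return f"{algo}-{normalize_env_name(env_id)}"


def model_repo_id(organization: str, algo: str, env_id: str) -> str:
    """Repository ID of the assumed form '{organization}/{model name}'."""
    return f"{organization}/{model_name(algo, env_id)}"


print(model_repo_id("ThomasSimonini", "ppo", "CartPole-v1"))  # → ThomasSimonini/ppo-CartPole-v1
print(model_name("dqn", "ALE/Pong-v5"))  # → dqn-ALE-Pong-v5
```

Keeping the normalization in one function means a namespaced ID such as ALE/Pong-v5 can never produce an extra path segment in a repository ID.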

    opened by ernestum 8
  • 400 Client Error for `package_to_hub` function

    I am going through the notebook of Unit 1 of the deep RL course. However, I cannot run the package_to_hub function, which gives the following error:

    HTTPError                                 Traceback (most recent call last)
    
    <ipython-input-26-97f48e41190b> in <module>
         25                eval_env=eval_env,
         26                repo_id="LorenzoPacchiardi/ppo-LunarLander-v2",
    ---> 27                commit_message="Upload PPO LunarLander-v2 trained agent (50 steps)")
    
    6 frames
    
    /usr/local/lib/python3.7/dist-packages/requests/models.py in raise_for_status(self)
        939 
        940         if http_error_msg:
    --> 941             raise HTTPError(http_error_msg, response=self)
        942 
        943     def close(self):
    
    HTTPError: 400 Client Error: Bad Request for url: https://huggingface.co/api/models/LorenzoPacchiardi/ppo-LunarLander-v2/commit/main (Request ID: fhQtAuS_qa8bj_c6AI0v5)
    

    I get a similar error with push_to_hub.

    I logged in to Hugging Face correctly with my token, and the load_from_hub function works fine.

    opened by LoryPack 5
  • package_to_hub requires OpenGL and xvfb which are not present on newer Mac OS systems

    Currently, package_to_hub works only on OpenGL-capable computers. It offers no other option for generating the video, and it doesn't allow uploading a model without a video. Newer macOS versions no longer have OpenGL support.

    opened by marcin-sobocinski 3
  • Error installing huggingface_sb3

    Hi! I'm running this notebook https://github.com/huggingface/deep-rl-class/blob/main/unit1/unit1.ipynb from your Deep RL course. Installing some of the libraries fails. For huggingface_sb3, the output is:

    Collecting huggingface_sb3
      Using cached huggingface_sb3-2.0.0-py3-none-any.whl (7.4 kB)
    Requirement already satisfied: wasabi in ./rl/lib/python3.8/site-packages (from huggingface_sb3) (0.9.1)
    Collecting cloudpickle==1.6
      Using cached cloudpickle-1.6.0-py3-none-any.whl (23 kB)
    Collecting pyyaml==6.0
      Using cached PyYAML-6.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (701 kB)
    Collecting huggingface-hub
      Using cached huggingface_hub-0.6.0-py3-none-any.whl (84 kB)
    Collecting pickle5
      Using cached pickle5-0.0.11.tar.gz (132 kB)
      Preparing metadata (setup.py) ... done
    Collecting typing-extensions>=3.7.4.3
      Using cached typing_extensions-4.2.0-py3-none-any.whl (24 kB)
    Requirement already satisfied: packaging>=20.9 in ./rl/lib/python3.8/site-packages (from huggingface-hub->huggingface_sb3) (21.3)
    Collecting filelock
      Using cached filelock-3.7.0-py3-none-any.whl (10 kB)
    Collecting tqdm
      Using cached tqdm-4.64.0-py2.py3-none-any.whl (78 kB)
    Collecting requests
      Using cached requests-2.27.1-py2.py3-none-any.whl (63 kB)
    Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in ./rl/lib/python3.8/site-packages (from packaging>=20.9->huggingface-hub->huggingface_sb3) (3.0.9)
    Collecting urllib3<1.27,>=1.21.1
      Using cached urllib3-1.26.9-py2.py3-none-any.whl (138 kB)
    Collecting certifi>=2017.4.17
      Using cached certifi-2021.10.8-py2.py3-none-any.whl (149 kB)
    Collecting charset-normalizer~=2.0.0
      Using cached charset_normalizer-2.0.12-py3-none-any.whl (39 kB)
    Collecting idna<4,>=2.5
      Using cached idna-3.3-py3-none-any.whl (61 kB)
    Using legacy 'setup.py install' for pickle5, since package 'wheel' is not installed.
    Installing collected packages: pickle5, certifi, urllib3, typing-extensions, tqdm, pyyaml, idna, filelock, cloudpickle, charset-normalizer, requests, huggingface-hub, huggingface_sb3
      Running setup.py install for pickle5 ... error
      error: subprocess-exited-with-error
      
      × Running setup.py install for pickle5 did not run successfully.
      │ exit code: 1
      ╰─> [27 lines of output]
          running install
          /media/master/support/pip_envs/rl/lib/python3.8/site-packages/setuptools/command/install.py:34: SetuptoolsDeprecationWarning: setup.py install is deprecated. Use build and pip and other standards-based tools.
            warnings.warn(
          running build
          running build_py
          creating build
          creating build/lib.linux-x86_64-cpython-38
          creating build/lib.linux-x86_64-cpython-38/pickle5
          copying pickle5/__init__.py -> build/lib.linux-x86_64-cpython-38/pickle5
          copying pickle5/pickle.py -> build/lib.linux-x86_64-cpython-38/pickle5
          copying pickle5/pickletools.py -> build/lib.linux-x86_64-cpython-38/pickle5
          creating build/lib.linux-x86_64-cpython-38/pickle5/test
          copying pickle5/test/pickletester.py -> build/lib.linux-x86_64-cpython-38/pickle5/test
          copying pickle5/test/test_picklebuffer.py -> build/lib.linux-x86_64-cpython-38/pickle5/test
          copying pickle5/test/__init__.py -> build/lib.linux-x86_64-cpython-38/pickle5/test
          copying pickle5/test/test_pickle.py -> build/lib.linux-x86_64-cpython-38/pickle5/test
          running build_ext
          building 'pickle5._pickle' extension
          creating build/temp.linux-x86_64-cpython-38
          creating build/temp.linux-x86_64-cpython-38/pickle5
          x86_64-linux-gnu-gcc -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -fPIC -I/media/master/support/pip_envs/rl/include -I/usr/include/python3.8 -c pickle5/_pickle.c -o build/temp.linux-x86_64-cpython-38/pickle5/_pickle.o -std=c99
          In file included from pickle5/_pickle.c:2:
          pickle5/compat.h:1:10: fatal error: Python.h: No such file or directory
              1 | #include "Python.h"
                |          ^~~~~~~~~~
          compilation terminated.
          error: command '/usr/bin/x86_64-linux-gnu-gcc' failed with exit code 1
          [end of output]
      
      note: This error originates from a subprocess, and is likely not a problem with pip.
    error: legacy-install-failure
    
    × Encountered error while trying to install package.
    ╰─> pickle5
    
    note: This is an issue with the package mentioned above, not pip.
    

    I tried installing it with Python 3.9 and 3.8 on Ubuntu 22.04.

    Are there any additional requirements to use your library?

    opened by kirilllzaitsev 3
  • generate_replay created a video.mp4 file locally

    This code snippet

    env = VecVideoRecorder(
            eval_env,
            "./",  # Temporary video folder
            record_video_trigger=lambda x: x == 0,
            video_length=video_length,
            name_prefix="",
        )
    

    generated a video file in the user's current working directory. For the temporary video folder, you can use tempfile.TemporaryDirectory() to create a directory that is deleted automatically afterwards.
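A minimal sketch of the tempfile.TemporaryDirectory() suggestion, using only the standard library. record_video_stub is a hypothetical stand-in for the VecVideoRecorder call (it only writes placeholder bytes), and the destination folder and the replay.mp4 file name are assumptions; the point is that only the final video is copied out before the temporary folder is deleted.

```python
import os
import shutil
import tempfile
from typing import Tuple


def record_video_stub(video_folder: str) -> str:
    """Hypothetical stand-in for VecVideoRecorder: writes a dummy .mp4 into video_folder."""
    path = os.path.join(video_folder, "replay-step-0.mp4")
    with open(path, "wb") as f:
        f.write(b"not a real video")  # placeholder bytes
    return path


def generate_replay(destination_dir: str) -> Tuple[str, str]:
    """Record into a throwaway folder and keep only the final video in destination_dir."""
    os.makedirs(destination_dir, exist_ok=True)
    with tempfile.TemporaryDirectory() as tmp_dir:
        recorded = record_video_stub(tmp_dir)
        final_path = os.path.join(destination_dir, "replay.mp4")
        shutil.copyfile(recorded, final_path)
    # Leaving the `with` block deletes tmp_dir and everything written into it.
    return final_path, tmp_dir


out_dir = tempfile.mkdtemp()  # stand-in for the local folder that gets uploaded
video, scratch = generate_replay(out_dir)
print(os.path.exists(video), os.path.exists(scratch))  # True False
```

Nothing is written to the current working directory, so running the packaging step from any folder leaves it untouched.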

    opened by osanseviero 2
  • Rebase repo when pulling

    I think you want to rebase when pulling, as we do in all the mixins in huggingface_hub. Not rebasing is probably what caused this issue: https://github.com/huggingface/deep-rl-class/issues/20.

    opened by nateraw 2
  • Add `VecNormalize` support

    • add missing type hints
    • fix push_to_hub (bug detected by pytype checker)
    • cleanup
    • add support for VecNormalize

    closes #6

    Demo (and training/loading code): https://huggingface.co/araffin/a2c-Pendulum-v1

    opened by araffin 1
  • Add auto release

    With this PR, pushing a Git tag that matches v* (for example, v1.0.8) automatically creates a PyPI release. Ideally, the tag should point to the commit that updates the version on this line: https://github.com/huggingface/huggingface_sb3/blob/main/setup.py#L10 (you can push the tag after that commit).

    The only requirement is adding your PYPI_TOKEN_DIST secret to the repository settings.
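As an optional safeguard (not part of this PR), a release job could check that the pushed tag matches the version declared in setup.py before publishing. The sketch below assumes setup.py declares the version as a keyword argument such as version="2.2.4"; the function names are hypothetical.

```python
import re


def version_from_setup_py(setup_py_text: str) -> str:
    """Extract the version string from a keyword argument like version="2.2.4"."""
    match = re.search(r"""version\s*=\s*["']([^"']+)["']""", setup_py_text)
    if match is None:
        raise ValueError("no version=... argument found in setup.py")
    return match.group(1)


def tag_matches_version(tag: str, setup_py_text: str) -> bool:
    """True if a release tag such as 'v2.2.4' matches the version declared in setup.py."""
    return tag.startswith("v") and tag[1:] == version_from_setup_py(setup_py_text)


example_setup = 'setup(\n    name="huggingface_sb3",\n    version="2.2.4",\n)\n'
print(tag_matches_version("v2.2.4", example_setup))  # True
print(tag_matches_version("v2.2.3", example_setup))  # False
```

Failing the job on a mismatch avoids publishing a PyPI release whose version number disagrees with the tag.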

    opened by osanseviero 1
  • Allow to pass TensorBoard logs files to package_to_hub

    It's very easy to add TensorBoard logging with SB3, but pushing the log files currently has to be done manually. As an alternative, we could add a parameter to package_to_hub for passing the logs.

    Related: https://github.com/huggingface/deep-rl-class/pull/19
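A rough sketch of what such a parameter could do: collect the TensorBoard event files (named events.out.tfevents.*) from a log directory and copy them into the local folder that gets uploaded. copy_tensorboard_logs and the logs/ subfolder are hypothetical; the issue describes package_to_hub as having no such option.

```python
import glob
import os
import shutil
from typing import List


def copy_tensorboard_logs(log_dir: str, repo_local_path: str) -> List[str]:
    """Copy TensorBoard event files from log_dir into repo_local_path/logs, keeping subfolders.

    Hypothetical helper; not part of huggingface_sb3.
    """
    # With recursive=True, "**" matches zero or more directories, so top-level files count too.
    pattern = os.path.join(log_dir, "**", "events.out.tfevents.*")
    copied = []
    for src in sorted(glob.glob(pattern, recursive=True)):
        rel = os.path.relpath(src, log_dir)  # e.g. a run folder such as PPO_1/events.out.tfevents...
        dst = os.path.join(repo_local_path, "logs", rel)
        os.makedirs(os.path.dirname(dst), exist_ok=True)
        shutil.copy2(src, dst)
        copied.append(dst)
    return copied
```

Only event files are copied, so unrelated files that happen to sit in the log directory are not uploaded.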

    opened by osanseviero 0
  • Don't crash when making videos causes problems

    At the moment, if generate_replay fails, the whole package_to_hub call fails. Ideally, it would still push the metrics and other related information even if no video is generated.
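A minimal sketch of the requested behavior: treat the replay video as optional, so a failure while recording (for example, no OpenGL or no virtual display) is logged and the remaining files are still collected for upload. package_files and its callable arguments are hypothetical placeholders for the real packaging steps.

```python
import logging
from typing import Callable, List

logger = logging.getLogger(__name__)


def package_files(
    save_model: Callable[[], str],
    evaluate: Callable[[], str],
    make_card: Callable[[], str],
    record_video: Callable[[], str],
) -> List[str]:
    """Collect the files to upload; a failing video step is logged and skipped."""
    files = [save_model(), evaluate(), make_card()]
    try:
        files.append(record_video())
    except Exception as exc:  # e.g. no OpenGL / no display available for rendering
        logger.warning("Skipping replay video: %s", exc)
    return files


def broken_video() -> str:
    raise RuntimeError("no display available")


print(package_files(lambda: "model.zip", lambda: "results.json", lambda: "README.md", broken_video))
# ['model.zip', 'results.json', 'README.md']
```

Only the video step is wrapped: a failure while saving or evaluating the model should still stop the upload.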

    opened by osanseviero 0
  • Huggingface_SB3 v2.0

    👋 Here's huggingface_sb3 v2.0:

    With the new version, you can use the package_to_hub method, which:

    1. Saves the model
    2. Evaluates the model and generates a results.json
    3. Generates a model card
    4. Records a replay video of the agent
    5. Pushes everything to the Hub

    Here's an example: https://huggingface.co/ThomasSimonini/TEST2-Colab-ppo-LunarLander-v2 (trained only briefly, so the agent is bad)
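Step 2 of the list produces a results.json. The sketch below shows one plausible way to summarize evaluation episodes into such a file; the key names (env_id, mean_reward, std_reward, n_evaluation_episodes) are assumptions, not the library's actual schema. It uses the population standard deviation, which is what numpy.std computes by default.

```python
import json
import statistics
from typing import Dict, List


def summarize_episodes(episode_rewards: List[float], env_id: str) -> Dict[str, object]:
    """Summarize evaluation episodes; the key names are illustrative assumptions."""
    return {
        "env_id": env_id,
        "mean_reward": statistics.mean(episode_rewards),
        "std_reward": statistics.pstdev(episode_rewards),  # population std, like numpy.std
        "n_evaluation_episodes": len(episode_rewards),
    }


def write_results(path: str, episode_rewards: List[float], env_id: str) -> None:
    """Write the summary as pretty-printed JSON, e.g. to results.json."""
    with open(path, "w") as f:
        json.dump(summarize_episodes(episode_rewards, env_id), f, indent=2)


print(summarize_episodes([190.0, 210.0], "LunarLander-v2"))
# {'env_id': 'LunarLander-v2', 'mean_reward': 200.0, 'std_reward': 10.0, 'n_evaluation_episodes': 2}
```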

    If you want to try some examples directly on Colab, I've made a small test notebook: https://colab.research.google.com/drive/1FhZ1w7smqPo8GQcW5qb2HmkggZVuok57?usp=sharing

    PyPI releases are also automated, thanks to @osanseviero.

    A big thanks to Omar, who ran a lot of tests on the library.

    I'd like some feedback on the documentation; I don't think it's very clear yet.

    • I explain the 2 cases: downstream and upstream.
    • In cases 3 and 4, I explain how to use xvfb on Colab or a VM (there is no screen to render to, so you can't generate a video without xvfb).

    WDYT? Thanks

    opened by simoninithomas 0
Releases (v2.2.4)
  • v2.2.4 (Oct 13, 2022)

    What's Changed

    • Loosen the requirements by @araffin in https://github.com/huggingface/huggingface_sb3/pull/19

    Full Changelog: https://github.com/huggingface/huggingface_sb3/compare/v2.2.3...v2.2.4

  • v2.2.3 (Aug 5, 2022)

    Colab ships cloudpickle 1.3 by default, but package_to_hub and load_from_hub need at least cloudpickle 1.6 to work correctly.
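A small sketch for checking the installed cloudpickle version against the 1.6 minimum before calling package_to_hub or load_from_hub. parse_release and meets_minimum are hypothetical helpers; the parser is deliberately simple and only compares the leading numeric release components (pre-release and post-release suffixes are ignored).

```python
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Tuple


def parse_release(ver: str) -> Tuple[int, ...]:
    """Parse the leading numeric components, e.g. '1.6.0' -> (1, 6, 0), '2.0.0rc1' -> (2, 0, 0)."""
    parts = []
    for piece in ver.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
        if match.end() != len(piece):  # stop at a suffix such as '0rc1'
            break
    return tuple(parts)


def meets_minimum(ver: str, minimum: Tuple[int, ...] = (1, 6)) -> bool:
    """True if ver is at least the minimum release (compared as integer tuples)."""
    return parse_release(ver) >= minimum


try:
    installed = version("cloudpickle")
except PackageNotFoundError:
    installed = None

if installed is None:
    print("cloudpickle is not installed")
elif not meets_minimum(installed):
    print(f"cloudpickle {installed} is older than 1.6; try: pip install 'cloudpickle>=1.6'")
```

Comparing integer tuples rather than strings avoids the classic mistake of treating "1.10" as older than "1.6".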

    Full Changelog: https://github.com/huggingface/huggingface_sb3/compare/v2.2.2...v2.2.3

  • v2.2.2 (Aug 1, 2022)

    What's Changed

    • V2.2.2 by @simoninithomas in https://github.com/huggingface/huggingface_sb3/pull/17
    • Pinning this dependency led to some upload problems, so we removed it.

    Full Changelog: https://github.com/huggingface/huggingface_sb3/compare/v2.2.1...v2.2.2

  • v2.2.1 (Jul 8, 2022)

    What's Changed

    • Notebook fixes by @ernestum in https://github.com/huggingface/huggingface_sb3/pull/12
    • Environment name normalization and explicit naming schemes by @ernestum in https://github.com/huggingface/huggingface_sb3/pull/13
    • Use new upload_folder API by @osanseviero in https://github.com/huggingface/huggingface_sb3/pull/15

    New Contributors

    • @ernestum made their first contribution in https://github.com/huggingface/huggingface_sb3/pull/12

    Full Changelog: https://github.com/huggingface/huggingface_sb3/compare/v2.1.1...v2.2

  • v2.1.0 (May 20, 2022)

    What's Changed

    • Use make_vec_env to create envs by @araffin in https://github.com/huggingface/huggingface_sb3/pull/3
    • Rebase repo when pulling by @nateraw in https://github.com/huggingface/huggingface_sb3/pull/7
    • Fix record video for RecurrentPPO by @araffin in https://github.com/huggingface/huggingface_sb3/pull/8
    • Add VecNormalize support by @araffin in https://github.com/huggingface/huggingface_sb3/pull/10

    New Contributors

    • @araffin made their first contribution in https://github.com/huggingface/huggingface_sb3/pull/3
    • @nateraw made their first contribution in https://github.com/huggingface/huggingface_sb3/pull/7

    Full Changelog: https://github.com/huggingface/huggingface_sb3/compare/v2.0.0...v2.1.0

Owner
Hugging Face
The AI community building the future.