added primitive joblib launcher
This commit is contained in:
parent
9c645ff4a0
commit
ec2c5c19d1
35
examples/mujoco/joblib_launcher.py
Normal file
35
examples/mujoco/joblib_launcher.py
Normal file
@ -0,0 +1,35 @@
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Literal
|
||||
|
||||
from joblib import Parallel, delayed
|
||||
|
||||
from tianshou.highlevel.experiment import Experiment
|
||||
|
||||
|
||||
@dataclass
|
||||
class JoblibConfig:
|
||||
n_jobs: int = -1
|
||||
"""The maximum number of concurrently running jobs. If -1, all CPUs are used."""
|
||||
backend: Literal["loky", "multiprocessing", "threading"] | None = None
|
||||
"""Allows to hard-code backend, otherwise inferred based on prefer and require."""
|
||||
verbose: int = 10
|
||||
"""If greater than zero, prints progress messages."""
|
||||
|
||||
|
||||
class JoblibLauncher:
|
||||
def __init__(self, joblib_cfg: JoblibConfig) -> None:
|
||||
self.joblib_cfg = joblib_cfg
|
||||
# Joblib's backend is hard-coded to loky since the threading backend produces different results
|
||||
self.joblib_cfg.backend = "loky"
|
||||
|
||||
def launch(self, experiments: dict[str, Experiment]) -> None:
|
||||
Parallel(**asdict(self.joblib_cfg))(
|
||||
delayed(self.execute_task)(exp, exp_name) for exp_name, exp in experiments.items()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def execute_task(exp: Experiment, name: str):
|
||||
try:
|
||||
exp.run(name)
|
||||
except Exception as e:
|
||||
print(e)
|
@ -1,13 +1,16 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import os
|
||||
import sys
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
|
||||
from examples.mujoco.joblib_launcher import JoblibConfig, JoblibLauncher
|
||||
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
||||
from tianshou.highlevel.config import SamplingConfig
|
||||
from tianshou.highlevel.env import VectorEnvType
|
||||
from tianshou.highlevel.evaluation import RLiableExperimentResult
|
||||
from tianshou.highlevel.experiment import (
|
||||
ExperimentConfig,
|
||||
@ -49,6 +52,7 @@ def main(
|
||||
value_clip: bool = False,
|
||||
norm_adv: bool = False,
|
||||
recompute_adv: bool = True,
|
||||
run_sequential: bool = False,
|
||||
) -> str:
|
||||
"""Use the high-level API of TianShou to evaluate the PPO algorithm on a MuJoCo environment with multiple seeds for
|
||||
a given configuration. The results for each run are stored in separate sub-folders. After the agents are trained,
|
||||
@ -75,6 +79,9 @@ def main(
|
||||
train_seed=sampling_config.train_seed,
|
||||
test_seed=sampling_config.test_seed,
|
||||
obs_norm=True,
|
||||
venv_type=VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT
|
||||
if sys.platform == "darwin"
|
||||
else VectorEnvType.SUBPROC_SHARED_MEM,
|
||||
)
|
||||
|
||||
experiments = (
|
||||
@ -105,8 +112,12 @@ def main(
|
||||
.build_default_seeded_experiments(num_experiments)
|
||||
)
|
||||
|
||||
for experiment_name, experiment in experiments.items():
|
||||
experiment.run(experiment_name)
|
||||
if run_sequential:
|
||||
for experiment_name, experiment in experiments.items():
|
||||
experiment.run(experiment_name)
|
||||
else:
|
||||
launcher = JoblibLauncher(JoblibConfig())
|
||||
launcher.launch(experiments)
|
||||
|
||||
return log_name
|
||||
|
||||
|
15
poetry.lock
generated
15
poetry.lock
generated
@ -2133,6 +2133,17 @@ MarkupSafe = ">=2.0"
|
||||
[package.extras]
|
||||
i18n = ["Babel (>=2.7)"]
|
||||
|
||||
[[package]]
|
||||
name = "joblib"
|
||||
version = "1.3.2"
|
||||
description = "Lightweight pipelining with Python functions"
|
||||
optional = true
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "joblib-1.3.2-py3-none-any.whl", hash = "sha256:ef4331c65f239985f3f2220ecc87db222f08fd22097a3dd5698f693875f8cbb9"},
|
||||
{file = "joblib-1.3.2.tar.gz", hash = "sha256:92f865e621e17784e7955080b6d042489e3b8e294949cc44c6eac304f59772b1"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "json5"
|
||||
version = "0.9.14"
|
||||
@ -6892,7 +6903,7 @@ atari = ["ale-py", "autorom", "opencv_python", "shimmy"]
|
||||
box2d = ["box2d_py", "pygame", "swig"]
|
||||
classic-control = ["pygame"]
|
||||
envpool = ["envpool"]
|
||||
eval = ["rliable"]
|
||||
eval = ["joblib", "rliable"]
|
||||
mujoco = ["imageio", "mujoco"]
|
||||
mujoco-py = ["cython", "mujoco-py"]
|
||||
pybullet = ["pybullet"]
|
||||
@ -6902,4 +6913,4 @@ vizdoom = ["vizdoom"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.11"
|
||||
content-hash = "0ac399a139f7e0e64f481823d6f079a385676abf4836d1b532bd0cb86589c66e"
|
||||
content-hash = "bb0c689d4db6fe4de11d6bd55fe9fd4425739ef9a603ec5bd57f2c48e0ada9c7"
|
||||
|
@ -58,6 +58,7 @@ docstring-parser = { version = "^0.15", optional = true }
|
||||
envpool = { version = "^0.8.2", optional = true, markers = "sys_platform != 'darwin'"}
|
||||
gymnasium-robotics = { version = "*", optional = true }
|
||||
imageio = { version = ">=2.14.1", optional = true }
|
||||
joblib = { version = "*", optional = true }
|
||||
jsonargparse = {version = "^4.24.1", optional = true}
|
||||
mujoco = { version = ">=2.1.5", optional = true }
|
||||
mujoco-py = { version = ">=2.1,<2.2", optional = true }
|
||||
@ -79,7 +80,7 @@ pybullet = ["pybullet"]
|
||||
envpool = ["envpool"]
|
||||
robotics = ["gymnasium-robotics"]
|
||||
vizdoom = ["vizdoom"]
|
||||
eval = ["rliable"]
|
||||
eval = ["rliable", "joblib"]
|
||||
|
||||
|
||||
[tool.poetry.group.dev]
|
||||
|
Loading…
x
Reference in New Issue
Block a user