diff --git a/examples/mujoco/joblib_launcher.py b/examples/mujoco/joblib_launcher.py
new file mode 100644
index 0000000..f8d60a2
--- /dev/null
+++ b/examples/mujoco/joblib_launcher.py
@@ -0,0 +1,58 @@
+import traceback
+from dataclasses import asdict, dataclass, replace
+from typing import Literal
+
+from joblib import Parallel, delayed
+
+from tianshou.highlevel.experiment import Experiment
+
+
+@dataclass
+class JoblibConfig:
+    """Settings that `JoblibLauncher` passes as keyword arguments to `joblib.Parallel`."""
+
+    n_jobs: int = -1
+    """The maximum number of concurrently running jobs. If -1, all CPUs are used."""
+    backend: Literal["loky", "multiprocessing", "threading"] | None = None
+    """Hard-codes the backend; if None, it is inferred based on prefer and require.
+
+    Note that `JoblibLauncher` always replaces this value with "loky".
+    """
+    verbose: int = 10
+    """If greater than zero, prints progress messages."""
+
+
+class JoblibLauncher:
+    """Runs a collection of named experiments as parallel joblib tasks."""
+
+    def __init__(self, joblib_cfg: JoblibConfig) -> None:
+        """
+        :param joblib_cfg: the joblib settings; its backend is ignored and "loky" is used instead.
+            The passed object itself is not modified.
+        """
+        # Joblib's backend is hard-coded to loky since the threading backend produces different
+        # results. A copy is stored so that the caller's config object is not silently mutated.
+        self.joblib_cfg = replace(joblib_cfg, backend="loky")
+
+    def launch(self, experiments: dict[str, Experiment]) -> None:
+        """Submit one joblib task per experiment.
+
+        :param experiments: mapping from experiment name to the experiment to run.
+        """
+        Parallel(**asdict(self.joblib_cfg))(
+            delayed(self.execute_task)(exp, exp_name) for exp_name, exp in experiments.items()
+        )
+
+    @staticmethod
+    def execute_task(exp: Experiment, name: str) -> None:
+        """Run a single experiment, reporting (rather than propagating) any exception it raises.
+
+        :param exp: the experiment to run.
+        :param name: the experiment name, which is passed to `exp.run`.
+        """
+        try:
+            exp.run(name)
+        except Exception:
+            # The exception is deliberately not re-raised, so it does not propagate out of this
+            # task. Print the experiment name and the full traceback so the failure is diagnosable.
+            print(f"Experiment {name!r} failed:\n{traceback.format_exc()}")
diff --git a/examples/mujoco/mujoco_ppo_hl_multi.py b/examples/mujoco/mujoco_ppo_hl_multi.py
index 703cf68..ea90bf5 100644
--- a/examples/mujoco/mujoco_ppo_hl_multi.py
+++ b/examples/mujoco/mujoco_ppo_hl_multi.py
@@ -1,13 +1,16 @@
 #!/usr/bin/env python3
 
 import os
+import sys
 from collections.abc import Sequence
 from typing import Literal
 
 import torch
 
+from examples.mujoco.joblib_launcher import JoblibConfig, JoblibLauncher
 from examples.mujoco.mujoco_env import MujocoEnvFactory
 from tianshou.highlevel.config import SamplingConfig
+from tianshou.highlevel.env import VectorEnvType
 from tianshou.highlevel.evaluation import RLiableExperimentResult
 from tianshou.highlevel.experiment 
import ( ExperimentConfig, @@ -49,6 +52,7 @@ def main( value_clip: bool = False, norm_adv: bool = False, recompute_adv: bool = True, + run_sequential: bool = False, ) -> str: """Use the high-level API of TianShou to evaluate the PPO algorithm on a MuJoCo environment with multiple seeds for a given configuration. The results for each run are stored in separate sub-folders. After the agents are trained, @@ -75,6 +79,9 @@ def main( train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True, + venv_type=VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT + if sys.platform == "darwin" + else VectorEnvType.SUBPROC_SHARED_MEM, ) experiments = ( @@ -105,8 +112,12 @@ def main( .build_default_seeded_experiments(num_experiments) ) - for experiment_name, experiment in experiments.items(): - experiment.run(experiment_name) + if run_sequential: + for experiment_name, experiment in experiments.items(): + experiment.run(experiment_name) + else: + launcher = JoblibLauncher(JoblibConfig()) + launcher.launch(experiments) return log_name diff --git a/poetry.lock b/poetry.lock index 4fbb9ad..4ec4eee 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2133,6 +2133,17 @@ MarkupSafe = ">=2.0" [package.extras] i18n = ["Babel (>=2.7)"] +[[package]] +name = "joblib" +version = "1.3.2" +description = "Lightweight pipelining with Python functions" +optional = true +python-versions = ">=3.7" +files = [ + {file = "joblib-1.3.2-py3-none-any.whl", hash = "sha256:ef4331c65f239985f3f2220ecc87db222f08fd22097a3dd5698f693875f8cbb9"}, + {file = "joblib-1.3.2.tar.gz", hash = "sha256:92f865e621e17784e7955080b6d042489e3b8e294949cc44c6eac304f59772b1"}, +] + [[package]] name = "json5" version = "0.9.14" @@ -6892,7 +6903,7 @@ atari = ["ale-py", "autorom", "opencv_python", "shimmy"] box2d = ["box2d_py", "pygame", "swig"] classic-control = ["pygame"] envpool = ["envpool"] -eval = ["rliable"] +eval = ["joblib", "rliable"] mujoco = ["imageio", "mujoco"] mujoco-py = ["cython", "mujoco-py"] 
pybullet = ["pybullet"] @@ -6902,4 +6913,4 @@ vizdoom = ["vizdoom"] [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "0ac399a139f7e0e64f481823d6f079a385676abf4836d1b532bd0cb86589c66e" +content-hash = "bb0c689d4db6fe4de11d6bd55fe9fd4425739ef9a603ec5bd57f2c48e0ada9c7" diff --git a/pyproject.toml b/pyproject.toml index 5c3685d..7030d23 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,7 @@ docstring-parser = { version = "^0.15", optional = true } envpool = { version = "^0.8.2", optional = true, markers = "sys_platform != 'darwin'"} gymnasium-robotics = { version = "*", optional = true } imageio = { version = ">=2.14.1", optional = true } +joblib = { version = "*", optional = true } jsonargparse = {version = "^4.24.1", optional = true} mujoco = { version = ">=2.1.5", optional = true } mujoco-py = { version = ">=2.1,<2.2", optional = true } @@ -79,7 +80,7 @@ pybullet = ["pybullet"] envpool = ["envpool"] robotics = ["gymnasium-robotics"] vizdoom = ["vizdoom"] -eval = ["rliable"] +eval = ["rliable", "joblib"] [tool.poetry.group.dev]