added primitive joblib launcher

This commit is contained in:
Maximilian Huettenrauch 2024-03-27 17:38:01 +01:00
parent 9c645ff4a0
commit ec2c5c19d1
4 changed files with 63 additions and 5 deletions

View File

@ -0,0 +1,35 @@
from dataclasses import asdict, dataclass
from typing import Literal
from joblib import Parallel, delayed
from tianshou.highlevel.experiment import Experiment
@dataclass
class JoblibConfig:
n_jobs: int = -1
"""The maximum number of concurrently running jobs. If -1, all CPUs are used."""
backend: Literal["loky", "multiprocessing", "threading"] | None = None
"""Allows to hard-code backend, otherwise inferred based on prefer and require."""
verbose: int = 10
"""If greater than zero, prints progress messages."""
class JoblibLauncher:
def __init__(self, joblib_cfg: JoblibConfig) -> None:
self.joblib_cfg = joblib_cfg
# Joblib's backend is hard-coded to loky since the threading backend produces different results
self.joblib_cfg.backend = "loky"
def launch(self, experiments: dict[str, Experiment]) -> None:
Parallel(**asdict(self.joblib_cfg))(
delayed(self.execute_task)(exp, exp_name) for exp_name, exp in experiments.items()
)
@staticmethod
def execute_task(exp: Experiment, name: str):
try:
exp.run(name)
except Exception as e:
print(e)

View File

@ -1,13 +1,16 @@
#!/usr/bin/env python3
import os
import sys
from collections.abc import Sequence
from typing import Literal
import torch
from examples.mujoco.joblib_launcher import JoblibConfig, JoblibLauncher
from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import VectorEnvType
from tianshou.highlevel.evaluation import RLiableExperimentResult
from tianshou.highlevel.experiment import (
ExperimentConfig,
@ -49,6 +52,7 @@ def main(
value_clip: bool = False,
norm_adv: bool = False,
recompute_adv: bool = True,
run_sequential: bool = False,
) -> str:
"""Use the high-level API of TianShou to evaluate the PPO algorithm on a MuJoCo environment with multiple seeds for
a given configuration. The results for each run are stored in separate sub-folders. After the agents are trained,
@ -75,6 +79,9 @@ def main(
train_seed=sampling_config.train_seed,
test_seed=sampling_config.test_seed,
obs_norm=True,
venv_type=VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT
if sys.platform == "darwin"
else VectorEnvType.SUBPROC_SHARED_MEM,
)
experiments = (
@ -105,8 +112,12 @@ def main(
.build_default_seeded_experiments(num_experiments)
)
for experiment_name, experiment in experiments.items():
experiment.run(experiment_name)
if run_sequential:
for experiment_name, experiment in experiments.items():
experiment.run(experiment_name)
else:
launcher = JoblibLauncher(JoblibConfig())
launcher.launch(experiments)
return log_name

15
poetry.lock generated
View File

@ -2133,6 +2133,17 @@ MarkupSafe = ">=2.0"
[package.extras]
i18n = ["Babel (>=2.7)"]
[[package]]
name = "joblib"
version = "1.3.2"
description = "Lightweight pipelining with Python functions"
optional = true
python-versions = ">=3.7"
files = [
{file = "joblib-1.3.2-py3-none-any.whl", hash = "sha256:ef4331c65f239985f3f2220ecc87db222f08fd22097a3dd5698f693875f8cbb9"},
{file = "joblib-1.3.2.tar.gz", hash = "sha256:92f865e621e17784e7955080b6d042489e3b8e294949cc44c6eac304f59772b1"},
]
[[package]]
name = "json5"
version = "0.9.14"
@ -6892,7 +6903,7 @@ atari = ["ale-py", "autorom", "opencv_python", "shimmy"]
box2d = ["box2d_py", "pygame", "swig"]
classic-control = ["pygame"]
envpool = ["envpool"]
eval = ["rliable"]
eval = ["joblib", "rliable"]
mujoco = ["imageio", "mujoco"]
mujoco-py = ["cython", "mujoco-py"]
pybullet = ["pybullet"]
@ -6902,4 +6913,4 @@ vizdoom = ["vizdoom"]
[metadata]
lock-version = "2.0"
python-versions = "^3.11"
content-hash = "0ac399a139f7e0e64f481823d6f079a385676abf4836d1b532bd0cb86589c66e"
content-hash = "bb0c689d4db6fe4de11d6bd55fe9fd4425739ef9a603ec5bd57f2c48e0ada9c7"

View File

@ -58,6 +58,7 @@ docstring-parser = { version = "^0.15", optional = true }
envpool = { version = "^0.8.2", optional = true, markers = "sys_platform != 'darwin'"}
gymnasium-robotics = { version = "*", optional = true }
imageio = { version = ">=2.14.1", optional = true }
joblib = { version = "*", optional = true }
jsonargparse = {version = "^4.24.1", optional = true}
mujoco = { version = ">=2.1.5", optional = true }
mujoco-py = { version = ">=2.1,<2.2", optional = true }
@ -79,7 +80,7 @@ pybullet = ["pybullet"]
envpool = ["envpool"]
robotics = ["gymnasium-robotics"]
vizdoom = ["vizdoom"]
eval = ["rliable"]
eval = ["rliable", "joblib"]
[tool.poetry.group.dev]