added primitive joblib launcher
This commit is contained in:
parent
9c645ff4a0
commit
ec2c5c19d1
35
examples/mujoco/joblib_launcher.py
Normal file
35
examples/mujoco/joblib_launcher.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
from dataclasses import asdict, dataclass
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from joblib import Parallel, delayed
|
||||||
|
|
||||||
|
from tianshou.highlevel.experiment import Experiment
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class JoblibConfig:
|
||||||
|
n_jobs: int = -1
|
||||||
|
"""The maximum number of concurrently running jobs. If -1, all CPUs are used."""
|
||||||
|
backend: Literal["loky", "multiprocessing", "threading"] | None = None
|
||||||
|
"""Allows to hard-code backend, otherwise inferred based on prefer and require."""
|
||||||
|
verbose: int = 10
|
||||||
|
"""If greater than zero, prints progress messages."""
|
||||||
|
|
||||||
|
|
||||||
|
class JoblibLauncher:
    """Launches a collection of experiments in parallel using joblib.

    Each experiment is executed as its own joblib task. A failing experiment is
    reported on stdout and does not abort the remaining ones.
    """

    def __init__(self, joblib_cfg: JoblibConfig) -> None:
        """
        :param joblib_cfg: the joblib settings to use. The given object is not
            modified; the launcher keeps its own copy in which ``backend`` is
            always set to ``"loky"``.
        """
        # Imported locally because the module-level imports only provide `asdict` and `dataclass`.
        from dataclasses import replace

        # Joblib's backend is hard-coded to loky since the threading backend produces different results.
        # A copy is stored so that the caller's config object is not silently altered.
        self.joblib_cfg = replace(joblib_cfg, backend="loky")

    def launch(self, experiments: dict[str, Experiment]) -> None:
        """Run all given experiments through ``joblib.Parallel``.

        :param experiments: mapping from experiment name to the experiment to run;
            the name is passed on to the experiment's ``run`` method.
        """
        Parallel(**asdict(self.joblib_cfg))(
            delayed(self.execute_task)(exp, exp_name) for exp_name, exp in experiments.items()
        )

    @staticmethod
    def execute_task(exp: Experiment, name: str) -> None:
        """Run a single experiment, reporting (but not propagating) any exception.

        :param exp: the experiment to run.
        :param name: the name passed to ``exp.run``; also used to identify the
            experiment in the failure report.
        """
        import traceback

        try:
            exp.run(name)
        except Exception:
            # Deliberately best-effort: one failed run must not take down the other jobs.
            # The full traceback and the experiment name are printed, since the bare
            # exception message alone (e.g. for a KeyError) is often not enough to
            # identify what went wrong and where.
            print(f"Experiment {name!r} failed:\n{traceback.format_exc()}")
|
@ -1,13 +1,16 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from examples.mujoco.joblib_launcher import JoblibConfig, JoblibLauncher
|
||||||
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
||||||
from tianshou.highlevel.config import SamplingConfig
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
|
from tianshou.highlevel.env import VectorEnvType
|
||||||
from tianshou.highlevel.evaluation import RLiableExperimentResult
|
from tianshou.highlevel.evaluation import RLiableExperimentResult
|
||||||
from tianshou.highlevel.experiment import (
|
from tianshou.highlevel.experiment import (
|
||||||
ExperimentConfig,
|
ExperimentConfig,
|
||||||
@ -49,6 +52,7 @@ def main(
|
|||||||
value_clip: bool = False,
|
value_clip: bool = False,
|
||||||
norm_adv: bool = False,
|
norm_adv: bool = False,
|
||||||
recompute_adv: bool = True,
|
recompute_adv: bool = True,
|
||||||
|
run_sequential: bool = False,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Use the high-level API of TianShou to evaluate the PPO algorithm on a MuJoCo environment with multiple seeds for
|
"""Use the high-level API of TianShou to evaluate the PPO algorithm on a MuJoCo environment with multiple seeds for
|
||||||
a given configuration. The results for each run are stored in separate sub-folders. After the agents are trained,
|
a given configuration. The results for each run are stored in separate sub-folders. After the agents are trained,
|
||||||
@ -75,6 +79,9 @@ def main(
|
|||||||
train_seed=sampling_config.train_seed,
|
train_seed=sampling_config.train_seed,
|
||||||
test_seed=sampling_config.test_seed,
|
test_seed=sampling_config.test_seed,
|
||||||
obs_norm=True,
|
obs_norm=True,
|
||||||
|
venv_type=VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT
|
||||||
|
if sys.platform == "darwin"
|
||||||
|
else VectorEnvType.SUBPROC_SHARED_MEM,
|
||||||
)
|
)
|
||||||
|
|
||||||
experiments = (
|
experiments = (
|
||||||
@ -105,8 +112,12 @@ def main(
|
|||||||
.build_default_seeded_experiments(num_experiments)
|
.build_default_seeded_experiments(num_experiments)
|
||||||
)
|
)
|
||||||
|
|
||||||
for experiment_name, experiment in experiments.items():
|
if run_sequential:
|
||||||
experiment.run(experiment_name)
|
for experiment_name, experiment in experiments.items():
|
||||||
|
experiment.run(experiment_name)
|
||||||
|
else:
|
||||||
|
launcher = JoblibLauncher(JoblibConfig())
|
||||||
|
launcher.launch(experiments)
|
||||||
|
|
||||||
return log_name
|
return log_name
|
||||||
|
|
||||||
|
15
poetry.lock
generated
15
poetry.lock
generated
@ -2133,6 +2133,17 @@ MarkupSafe = ">=2.0"
|
|||||||
[package.extras]
|
[package.extras]
|
||||||
i18n = ["Babel (>=2.7)"]
|
i18n = ["Babel (>=2.7)"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "joblib"
|
||||||
|
version = "1.3.2"
|
||||||
|
description = "Lightweight pipelining with Python functions"
|
||||||
|
optional = true
|
||||||
|
python-versions = ">=3.7"
|
||||||
|
files = [
|
||||||
|
{file = "joblib-1.3.2-py3-none-any.whl", hash = "sha256:ef4331c65f239985f3f2220ecc87db222f08fd22097a3dd5698f693875f8cbb9"},
|
||||||
|
{file = "joblib-1.3.2.tar.gz", hash = "sha256:92f865e621e17784e7955080b6d042489e3b8e294949cc44c6eac304f59772b1"},
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "json5"
|
name = "json5"
|
||||||
version = "0.9.14"
|
version = "0.9.14"
|
||||||
@ -6892,7 +6903,7 @@ atari = ["ale-py", "autorom", "opencv_python", "shimmy"]
|
|||||||
box2d = ["box2d_py", "pygame", "swig"]
|
box2d = ["box2d_py", "pygame", "swig"]
|
||||||
classic-control = ["pygame"]
|
classic-control = ["pygame"]
|
||||||
envpool = ["envpool"]
|
envpool = ["envpool"]
|
||||||
eval = ["rliable"]
|
eval = ["joblib", "rliable"]
|
||||||
mujoco = ["imageio", "mujoco"]
|
mujoco = ["imageio", "mujoco"]
|
||||||
mujoco-py = ["cython", "mujoco-py"]
|
mujoco-py = ["cython", "mujoco-py"]
|
||||||
pybullet = ["pybullet"]
|
pybullet = ["pybullet"]
|
||||||
@ -6902,4 +6913,4 @@ vizdoom = ["vizdoom"]
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.11"
|
python-versions = "^3.11"
|
||||||
content-hash = "0ac399a139f7e0e64f481823d6f079a385676abf4836d1b532bd0cb86589c66e"
|
content-hash = "bb0c689d4db6fe4de11d6bd55fe9fd4425739ef9a603ec5bd57f2c48e0ada9c7"
|
||||||
|
@ -58,6 +58,7 @@ docstring-parser = { version = "^0.15", optional = true }
|
|||||||
envpool = { version = "^0.8.2", optional = true, markers = "sys_platform != 'darwin'"}
|
envpool = { version = "^0.8.2", optional = true, markers = "sys_platform != 'darwin'"}
|
||||||
gymnasium-robotics = { version = "*", optional = true }
|
gymnasium-robotics = { version = "*", optional = true }
|
||||||
imageio = { version = ">=2.14.1", optional = true }
|
imageio = { version = ">=2.14.1", optional = true }
|
||||||
|
joblib = { version = "*", optional = true }
|
||||||
jsonargparse = {version = "^4.24.1", optional = true}
|
jsonargparse = {version = "^4.24.1", optional = true}
|
||||||
mujoco = { version = ">=2.1.5", optional = true }
|
mujoco = { version = ">=2.1.5", optional = true }
|
||||||
mujoco-py = { version = ">=2.1,<2.2", optional = true }
|
mujoco-py = { version = ">=2.1,<2.2", optional = true }
|
||||||
@ -79,7 +80,7 @@ pybullet = ["pybullet"]
|
|||||||
envpool = ["envpool"]
|
envpool = ["envpool"]
|
||||||
robotics = ["gymnasium-robotics"]
|
robotics = ["gymnasium-robotics"]
|
||||||
vizdoom = ["vizdoom"]
|
vizdoom = ["vizdoom"]
|
||||||
eval = ["rliable"]
|
eval = ["rliable", "joblib"]
|
||||||
|
|
||||||
|
|
||||||
[tool.poetry.group.dev]
|
[tool.poetry.group.dev]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user