Tianshou/examples/mujoco/joblib_launcher.py
2024-03-27 17:38:01 +01:00

36 lines
1.2 KiB
Python

from dataclasses import asdict, dataclass
from typing import Literal
from joblib import Parallel, delayed
from tianshou.highlevel.experiment import Experiment
@dataclass
class JoblibConfig:
n_jobs: int = -1
"""The maximum number of concurrently running jobs. If -1, all CPUs are used."""
backend: Literal["loky", "multiprocessing", "threading"] | None = None
"""Allows to hard-code backend, otherwise inferred based on prefer and require."""
verbose: int = 10
"""If greater than zero, prints progress messages."""
class JoblibLauncher:
def __init__(self, joblib_cfg: JoblibConfig) -> None:
self.joblib_cfg = joblib_cfg
# Joblib's backend is hard-coded to loky since the threading backend produces different results
self.joblib_cfg.backend = "loky"
def launch(self, experiments: dict[str, Experiment]) -> None:
Parallel(**asdict(self.joblib_cfg))(
delayed(self.execute_task)(exp, exp_name) for exp_name, exp in experiments.items()
)
@staticmethod
def execute_task(exp: Experiment, name: str):
try:
exp.run(name)
except Exception as e:
print(e)