diff --git a/examples/mujoco/joblib_launcher.py b/examples/mujoco/joblib_launcher.py deleted file mode 100644 index f8d60a2..0000000 --- a/examples/mujoco/joblib_launcher.py +++ /dev/null @@ -1,35 +0,0 @@ -from dataclasses import asdict, dataclass -from typing import Literal - -from joblib import Parallel, delayed - -from tianshou.highlevel.experiment import Experiment - - -@dataclass -class JoblibConfig: - n_jobs: int = -1 - """The maximum number of concurrently running jobs. If -1, all CPUs are used.""" - backend: Literal["loky", "multiprocessing", "threading"] | None = None - """Allows to hard-code backend, otherwise inferred based on prefer and require.""" - verbose: int = 10 - """If greater than zero, prints progress messages.""" - - -class JoblibLauncher: - def __init__(self, joblib_cfg: JoblibConfig) -> None: - self.joblib_cfg = joblib_cfg - # Joblib's backend is hard-coded to loky since the threading backend produces different results - self.joblib_cfg.backend = "loky" - - def launch(self, experiments: dict[str, Experiment]) -> None: - Parallel(**asdict(self.joblib_cfg))( - delayed(self.execute_task)(exp, exp_name) for exp_name, exp in experiments.items() - ) - - @staticmethod - def execute_task(exp: Experiment, name: str): - try: - exp.run(name) - except Exception as e: - print(e) diff --git a/examples/mujoco/launcher.py b/examples/mujoco/launcher.py new file mode 100644 index 0000000..7d680f5 --- /dev/null +++ b/examples/mujoco/launcher.py @@ -0,0 +1,65 @@ +from abc import ABC, abstractmethod +from dataclasses import asdict, dataclass +from enum import Enum +from typing import Literal + +from joblib import Parallel, delayed + +from tianshou.highlevel.experiment import Experiment + + +@dataclass +class JoblibConfig: + n_jobs: int = -1 + """The maximum number of concurrently running jobs. If -1, all CPUs are used.""" + backend: Literal["loky", "multiprocessing", "threading"] | None = None + """Allows to hard-code backend, otherwise inferred based on prefer and require.""" + verbose: int = 10 + """If greater than zero, prints progress messages.""" + + +class ExpLauncher(ABC): + @abstractmethod + def launch(self, experiments: dict[str, Experiment]) -> None: + raise NotImplementedError + + +class SequentialExpLauncher(ExpLauncher): + def launch(self, experiments: dict[str, Experiment]) -> None: + for exp_name, exp in experiments.items(): + exp.run(exp_name) + + +class JoblibExpLauncher(ExpLauncher): + def __init__(self, joblib_cfg: JoblibConfig | None = None) -> None: + self.joblib_cfg = joblib_cfg or JoblibConfig() + # Joblib's backend is hard-coded to loky since the threading backend produces different results + self.joblib_cfg.backend = "loky" + + def launch(self, experiments: dict[str, Experiment]) -> None: + Parallel(**asdict(self.joblib_cfg))( + delayed(self.execute_task)(exp, exp_name) for exp_name, exp in experiments.items() + ) + + @staticmethod + def execute_task(exp: Experiment, name: str): + try: + exp.run(name) + except Exception as e: + print(e) + + +class RegisteredExpLauncher(Enum): + joblib = "joblib" + sequential = "sequential" + + def create_launcher(self): + match self: + case RegisteredExpLauncher.joblib: + return JoblibExpLauncher() + case RegisteredExpLauncher.sequential: + return SequentialExpLauncher() + case _: + raise NotImplementedError( + f"Launcher {self} is not implemented, registered launchers are {list(RegisteredExpLauncher)}.", + )