Tianshou/examples/mujoco/launcher.py

68 lines
2.0 KiB
Python
Raw Normal View History

2024-04-03 17:55:22 +02:00
import logging
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import asdict, dataclass, replace
from enum import Enum
from typing import Literal

from joblib import Parallel, delayed

from tianshou.highlevel.experiment import Experiment
2024-04-03 17:55:22 +02:00
# Module-scoped logger; named after this module so handlers/levels can target it.
log = logging.getLogger(name=__name__)
2024-04-03 17:27:46 +02:00
@dataclass
class JoblibConfig:
n_jobs: int = -1
"""The maximum number of concurrently running jobs. If -1, all CPUs are used."""
backend: Literal["loky", "multiprocessing", "threading"] | None = None
"""Allows to hard-code backend, otherwise inferred based on prefer and require."""
verbose: int = 10
"""If greater than zero, prints progress messages."""
class ExpLauncher(ABC):
    """Abstract strategy for executing a batch of experiments.

    Concrete subclasses decide how (e.g. sequentially or in parallel) the
    experiments are run.
    """

    @abstractmethod
    def launch(self, experiments: Sequence[Experiment]) -> None:
        """Execute every experiment in ``experiments``."""
2024-04-03 17:27:46 +02:00
class SequentialExpLauncher(ExpLauncher):
    """Launcher that runs experiments one after another by calling ``run()`` directly.

    Exceptions are not caught: if an experiment raises, the error propagates to the
    caller and the remaining experiments are not started.
    """

    def launch(self, experiments: Sequence[Experiment]) -> None:
        for experiment in experiments:
            experiment.run()
2024-04-03 17:27:46 +02:00
class JoblibExpLauncher(ExpLauncher):
    """Launcher that runs experiments concurrently through ``joblib.Parallel``.

    An experiment that raises an ``Exception`` is logged and does not prevent the
    other experiments from running.
    """

    def __init__(self, joblib_cfg: JoblibConfig | None = None) -> None:
        """
        :param joblib_cfg: settings forwarded to ``joblib.Parallel``; ``None`` means
            ``JoblibConfig()``. The given object is never modified: a copy is stored
            whose ``backend`` is always set to ``"loky"``.
        """
        cfg = joblib_cfg if joblib_cfg is not None else JoblibConfig()
        # Joblib's backend is hard-coded to loky since the threading backend produces different results
        if cfg.backend not in (None, "loky"):
            log.warning(
                "Ignoring requested joblib backend %r; JoblibExpLauncher always uses 'loky'.",
                cfg.backend,
            )
        # Store a modified copy instead of mutating the caller's (possibly shared) config object.
        self.joblib_cfg = replace(cfg, backend="loky")

    def launch(self, experiments: Sequence[Experiment]) -> None:
        """Run all ``experiments`` via ``joblib.Parallel`` configured from ``self.joblib_cfg``."""
        Parallel(**asdict(self.joblib_cfg))(delayed(self._safe_execute)(exp) for exp in experiments)

    @staticmethod
    def _safe_execute(exp: Experiment) -> None:
        """Run ``exp``, logging any exception (with traceback) instead of propagating it.

        Only ``Exception`` subclasses are caught, so ``KeyboardInterrupt`` and
        ``SystemExit`` still propagate and the run can be interrupted.
        """
        try:
            exp.run()
        except Exception:
            log.exception("Failed to run experiment %s.", exp)
2024-04-03 17:27:46 +02:00
class RegisteredExpLauncher(Enum):
    """Names of the available launcher implementations."""

    joblib = "joblib"
    sequential = "sequential"

    def create_launcher(self):
        """Instantiate the launcher that this member names, using default settings.

        :raises NotImplementedError: if the member has no associated launcher.
        """
        if self is RegisteredExpLauncher.joblib:
            return JoblibExpLauncher()
        if self is RegisteredExpLauncher.sequential:
            return SequentialExpLauncher()
        raise NotImplementedError(
            f"Launcher {self} is not yet implemented.",
        )