Adjusted launchers to new interface

This commit is contained in:
Michael Panchenko 2024-04-03 17:55:22 +02:00
parent 7d479af0bb
commit 60e75e38dc

View File

@ -1,3 +1,4 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import asdict, dataclass
@ -8,6 +9,8 @@ from joblib import Parallel, delayed
from tianshou.highlevel.experiment import Experiment
log = logging.getLogger(__name__)
@dataclass
class JoblibConfig:
@ -21,14 +24,14 @@ class JoblibConfig:
class ExpLauncher(ABC):
@abstractmethod
def launch(self, experiments: dict[str, Experiment]) -> None:
raise NotImplementedError
def launch(self, experiments: Sequence[Experiment]) -> None:
pass
class SequentialExpLauncher(ExpLauncher):
def launch(self, experiments: dict[str, Experiment]) -> None:
for exp_name, exp in experiments.items():
exp.run(exp_name)
def launch(self, experiments: Sequence[Experiment]) -> None:
for exp in experiments:
exp.run()
class JoblibExpLauncher(ExpLauncher):
@ -38,14 +41,14 @@ class JoblibExpLauncher(ExpLauncher):
self.joblib_cfg.backend = "loky"
def launch(self, experiments: Sequence[Experiment]) -> None:
Parallel(**asdict(self.joblib_cfg))(delayed(self.execute_task)(exp) for exp in experiments)
Parallel(**asdict(self.joblib_cfg))(delayed(self._safe_execute)(exp) for exp in experiments)
@staticmethod
def execute_task(exp: Experiment):
def _safe_execute(exp: Experiment):
try:
exp.run()
except Exception as e:
print(e)
except BaseException as e:
log.error(e)
class RegisteredExpLauncher(Enum):
@ -60,5 +63,5 @@ class RegisteredExpLauncher(Enum):
return SequentialExpLauncher()
case _:
raise NotImplementedError(
f"Launcher {self} is not implemented, registered launchers are {list(RegisteredExpLauncher)}.",
f"Launcher {self} is not yet implemented.",
)