Adjusted launchers to new interface

This commit is contained in:
Michael Panchenko 2024-04-03 17:55:22 +02:00
parent 7d479af0bb
commit 60e75e38dc

View File

@ -1,3 +1,4 @@
import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
@ -8,6 +9,8 @@ from joblib import Parallel, delayed
from tianshou.highlevel.experiment import Experiment from tianshou.highlevel.experiment import Experiment
log = logging.getLogger(__name__)
@dataclass @dataclass
class JoblibConfig: class JoblibConfig:
@ -21,14 +24,14 @@ class JoblibConfig:
class ExpLauncher(ABC): class ExpLauncher(ABC):
@abstractmethod @abstractmethod
def launch(self, experiments: dict[str, Experiment]) -> None: def launch(self, experiments: Sequence[Experiment]) -> None:
raise NotImplementedError pass
class SequentialExpLauncher(ExpLauncher): class SequentialExpLauncher(ExpLauncher):
def launch(self, experiments: dict[str, Experiment]) -> None: def launch(self, experiments: Sequence[Experiment]) -> None:
for exp_name, exp in experiments.items(): for exp in experiments:
exp.run(exp_name) exp.run()
class JoblibExpLauncher(ExpLauncher): class JoblibExpLauncher(ExpLauncher):
@ -38,14 +41,14 @@ class JoblibExpLauncher(ExpLauncher):
self.joblib_cfg.backend = "loky" self.joblib_cfg.backend = "loky"
def launch(self, experiments: Sequence[Experiment]) -> None: def launch(self, experiments: Sequence[Experiment]) -> None:
Parallel(**asdict(self.joblib_cfg))(delayed(self.execute_task)(exp) for exp in experiments) Parallel(**asdict(self.joblib_cfg))(delayed(self._safe_execute)(exp) for exp in experiments)
@staticmethod @staticmethod
def execute_task(exp: Experiment): def _safe_execute(exp: Experiment):
try: try:
exp.run() exp.run()
except Exception as e: except BaseException as e:
print(e) log.error(e)
class RegisteredExpLauncher(Enum): class RegisteredExpLauncher(Enum):
@ -60,5 +63,5 @@ class RegisteredExpLauncher(Enum):
return SequentialExpLauncher() return SequentialExpLauncher()
case _: case _:
raise NotImplementedError( raise NotImplementedError(
f"Launcher {self} is not implemented, registered launchers are {list(RegisteredExpLauncher)}.", f"Launcher {self} is not yet implemented.",
) )