Adjusted launchers to new interface
This commit is contained in:
parent
7d479af0bb
commit
60e75e38dc
@ -1,3 +1,4 @@
|
|||||||
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from dataclasses import asdict, dataclass
|
from dataclasses import asdict, dataclass
|
||||||
@ -8,6 +9,8 @@ from joblib import Parallel, delayed
|
|||||||
|
|
||||||
from tianshou.highlevel.experiment import Experiment
|
from tianshou.highlevel.experiment import Experiment
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class JoblibConfig:
|
class JoblibConfig:
|
||||||
@ -21,14 +24,14 @@ class JoblibConfig:
|
|||||||
|
|
||||||
class ExpLauncher(ABC):
|
class ExpLauncher(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def launch(self, experiments: dict[str, Experiment]) -> None:
|
def launch(self, experiments: Sequence[Experiment]) -> None:
|
||||||
raise NotImplementedError
|
pass
|
||||||
|
|
||||||
|
|
||||||
class SequentialExpLauncher(ExpLauncher):
|
class SequentialExpLauncher(ExpLauncher):
|
||||||
def launch(self, experiments: dict[str, Experiment]) -> None:
|
def launch(self, experiments: Sequence[Experiment]) -> None:
|
||||||
for exp_name, exp in experiments.items():
|
for exp in experiments:
|
||||||
exp.run(exp_name)
|
exp.run()
|
||||||
|
|
||||||
|
|
||||||
class JoblibExpLauncher(ExpLauncher):
|
class JoblibExpLauncher(ExpLauncher):
|
||||||
@ -38,14 +41,14 @@ class JoblibExpLauncher(ExpLauncher):
|
|||||||
self.joblib_cfg.backend = "loky"
|
self.joblib_cfg.backend = "loky"
|
||||||
|
|
||||||
def launch(self, experiments: Sequence[Experiment]) -> None:
|
def launch(self, experiments: Sequence[Experiment]) -> None:
|
||||||
Parallel(**asdict(self.joblib_cfg))(delayed(self.execute_task)(exp) for exp in experiments)
|
Parallel(**asdict(self.joblib_cfg))(delayed(self._safe_execute)(exp) for exp in experiments)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def execute_task(exp: Experiment):
|
def _safe_execute(exp: Experiment):
|
||||||
try:
|
try:
|
||||||
exp.run()
|
exp.run()
|
||||||
except Exception as e:
|
except BaseException as e:
|
||||||
print(e)
|
log.error(e)
|
||||||
|
|
||||||
|
|
||||||
class RegisteredExpLauncher(Enum):
|
class RegisteredExpLauncher(Enum):
|
||||||
@ -60,5 +63,5 @@ class RegisteredExpLauncher(Enum):
|
|||||||
return SequentialExpLauncher()
|
return SequentialExpLauncher()
|
||||||
case _:
|
case _:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"Launcher {self} is not implemented, registered launchers are {list(RegisteredExpLauncher)}.",
|
f"Launcher {self} is not yet implemented.",
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user