Added launcher interface and registry
This commit is contained in:
parent
f2e10b04bb
commit
85e910ec5d
@ -1,35 +0,0 @@
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Literal
|
||||
|
||||
from joblib import Parallel, delayed
|
||||
|
||||
from tianshou.highlevel.experiment import Experiment
|
||||
|
||||
|
||||
@dataclass
|
||||
class JoblibConfig:
|
||||
n_jobs: int = -1
|
||||
"""The maximum number of concurrently running jobs. If -1, all CPUs are used."""
|
||||
backend: Literal["loky", "multiprocessing", "threading"] | None = None
|
||||
"""Allows to hard-code backend, otherwise inferred based on prefer and require."""
|
||||
verbose: int = 10
|
||||
"""If greater than zero, prints progress messages."""
|
||||
|
||||
|
||||
class JoblibLauncher:
|
||||
def __init__(self, joblib_cfg: JoblibConfig) -> None:
|
||||
self.joblib_cfg = joblib_cfg
|
||||
# Joblib's backend is hard-coded to loky since the threading backend produces different results
|
||||
self.joblib_cfg.backend = "loky"
|
||||
|
||||
def launch(self, experiments: dict[str, Experiment]) -> None:
|
||||
Parallel(**asdict(self.joblib_cfg))(
|
||||
delayed(self.execute_task)(exp, exp_name) for exp_name, exp in experiments.items()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def execute_task(exp: Experiment, name: str):
|
||||
try:
|
||||
exp.run(name)
|
||||
except Exception as e:
|
||||
print(e)
|
65
examples/mujoco/launcher.py
Normal file
65
examples/mujoco/launcher.py
Normal file
@ -0,0 +1,65 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import asdict, dataclass
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
|
||||
from joblib import Parallel, delayed
|
||||
|
||||
from tianshou.highlevel.experiment import Experiment
|
||||
|
||||
|
||||
@dataclass
|
||||
class JoblibConfig:
|
||||
n_jobs: int = -1
|
||||
"""The maximum number of concurrently running jobs. If -1, all CPUs are used."""
|
||||
backend: Literal["loky", "multiprocessing", "threading"] | None = None
|
||||
"""Allows to hard-code backend, otherwise inferred based on prefer and require."""
|
||||
verbose: int = 10
|
||||
"""If greater than zero, prints progress messages."""
|
||||
|
||||
|
||||
class ExpLauncher(ABC):
|
||||
@abstractmethod
|
||||
def launch(self, experiments: dict[str, Experiment]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SequentialExpLauncher(ExpLauncher):
|
||||
def launch(self, experiments: dict[str, Experiment]) -> None:
|
||||
for exp_name, exp in experiments.items():
|
||||
exp.run(exp_name)
|
||||
|
||||
|
||||
class JoblibExpLauncher(ExpLauncher):
|
||||
def __init__(self, joblib_cfg: JoblibConfig | None = None) -> None:
|
||||
self.joblib_cfg = joblib_cfg or JoblibConfig()
|
||||
# Joblib's backend is hard-coded to loky since the threading backend produces different results
|
||||
self.joblib_cfg.backend = "loky"
|
||||
|
||||
def launch(self, experiments: dict[str, Experiment]) -> None:
|
||||
Parallel(**asdict(self.joblib_cfg))(
|
||||
delayed(self.execute_task)(exp, exp_name) for exp_name, exp in experiments.items()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def execute_task(exp: Experiment, name: str):
|
||||
try:
|
||||
exp.run(name)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
|
||||
class RegisteredExpLauncher(Enum):
|
||||
joblib = "joblib"
|
||||
sequential = "sequential"
|
||||
|
||||
def create_launcher(self):
|
||||
match self:
|
||||
case RegisteredExpLauncher.joblib:
|
||||
return JoblibExpLauncher()
|
||||
case RegisteredExpLauncher.sequential:
|
||||
return SequentialExpLauncher()
|
||||
case _:
|
||||
raise NotImplementedError(
|
||||
f"Launcher {self} is not implemented, registered launchers are {list(RegisteredExpLauncher)}.",
|
||||
)
|
Loading…
x
Reference in New Issue
Block a user