2024-04-21 01:25:33 +02:00
|
|
|
"""Provides a basic interface for launching experiments. The API is experimental and subject to change!."""
|
|
|
|
|
|
|
|
import logging
|
|
|
|
from abc import ABC, abstractmethod
|
2024-05-04 22:09:32 +02:00
|
|
|
from collections.abc import Callable, Sequence
|
2024-04-21 01:25:33 +02:00
|
|
|
from copy import copy
|
|
|
|
from dataclasses import asdict, dataclass
|
|
|
|
from enum import Enum
|
|
|
|
from typing import Literal
|
|
|
|
|
|
|
|
from joblib import Parallel, delayed
|
|
|
|
|
2024-05-04 22:09:32 +02:00
|
|
|
from tianshou.data import InfoStats
|
2024-04-21 01:25:33 +02:00
|
|
|
from tianshou.highlevel.experiment import Experiment
|
|
|
|
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
class JoblibConfig:
|
|
|
|
n_jobs: int = -1
|
|
|
|
"""The maximum number of concurrently running jobs. If -1, all CPUs are used."""
|
|
|
|
backend: Literal["loky", "multiprocessing", "threading"] | None = "loky"
|
|
|
|
"""Allows to hard-code backend, otherwise inferred based on prefer and require."""
|
|
|
|
verbose: int = 10
|
|
|
|
"""If greater than zero, prints progress messages."""
|
|
|
|
|
|
|
|
|
|
|
|
class ExpLauncher(ABC):
|
2024-05-04 22:09:32 +02:00
|
|
|
def __init__(
|
|
|
|
self,
|
|
|
|
experiment_runner: Callable[
|
|
|
|
[Experiment],
|
|
|
|
InfoStats | None,
|
|
|
|
] = lambda exp: exp.run().trainer_result,
|
|
|
|
):
|
|
|
|
""":param experiment_runner: can be used to override the default way in which an experiment is executed.
|
|
|
|
Can be useful e.g., if one wants to use the high-level interfaces to setup an experiment (or an experiment
|
|
|
|
collection) and tinker with it prior to execution. This need often arises when prototyping with mechanisms
|
|
|
|
that are not yet supported by the high-level interfaces.
|
|
|
|
Passing this allows arbitrary things to happen during experiment execution, so use it with caution!
|
|
|
|
"""
|
|
|
|
self.experiment_runner = experiment_runner
|
|
|
|
|
2024-04-21 01:25:33 +02:00
|
|
|
@abstractmethod
|
2024-05-04 22:09:32 +02:00
|
|
|
def _launch(self, experiments: Sequence[Experiment]) -> list[InfoStats | None]:
|
|
|
|
"""Should call `self.experiment_runner` for each experiment in experiments and aggregate the results."""
|
|
|
|
|
|
|
|
def _safe_execute(self, exp: Experiment) -> InfoStats | None | Literal["failed"]:
|
|
|
|
try:
|
|
|
|
return self.experiment_runner(exp)
|
|
|
|
except BaseException as e:
|
|
|
|
log.error(f"Failed to run experiment {exp}.", exc_info=e)
|
|
|
|
return "failed"
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
def _return_from_successful_and_failed_exps(
|
|
|
|
successful_exp_stats: list[InfoStats | None],
|
|
|
|
failed_exps: list[Experiment],
|
|
|
|
) -> list[InfoStats | None]:
|
|
|
|
if not successful_exp_stats:
|
|
|
|
raise RuntimeError("All experiments failed, see error logs for more details.")
|
|
|
|
if failed_exps:
|
|
|
|
log.error(
|
|
|
|
f"Failed to run the following "
|
|
|
|
f"{len(failed_exps)}/{len(successful_exp_stats) + len(failed_exps)} experiments: {failed_exps}. "
|
|
|
|
f"See the logs for more details. "
|
|
|
|
f"Returning the results of {len(successful_exp_stats)} successful experiments.",
|
|
|
|
)
|
|
|
|
return successful_exp_stats
|
|
|
|
|
|
|
|
def launch(self, experiments: Sequence[Experiment]) -> list[InfoStats | None]:
|
|
|
|
"""Will return the results of successfully executed experiments.
|
|
|
|
|
|
|
|
If a single experiment is passed, will not use parallelism and run it in the main process.
|
|
|
|
Failed experiments will be logged, and a RuntimeError is only raised if all experiments have failed.
|
|
|
|
"""
|
|
|
|
if len(experiments) == 1:
|
|
|
|
log.info(
|
|
|
|
"A single experiment is being run, will not use parallelism and run it in the main process.",
|
|
|
|
)
|
|
|
|
return [self.experiment_runner(experiments[0])]
|
|
|
|
return self._launch(experiments)
|
2024-04-21 01:25:33 +02:00
|
|
|
|
|
|
|
|
|
|
|
class SequentialExpLauncher(ExpLauncher):
|
2024-05-04 22:09:32 +02:00
|
|
|
"""Convenience wrapper around a simple for loop to run experiments sequentially."""
|
|
|
|
|
|
|
|
def _launch(self, experiments: Sequence[Experiment]) -> list[InfoStats | None]:
|
|
|
|
successful_exp_stats = []
|
|
|
|
failed_exps = []
|
2024-04-21 01:25:33 +02:00
|
|
|
for exp in experiments:
|
2024-05-04 22:09:32 +02:00
|
|
|
for exp in experiments:
|
|
|
|
exp_stats = self._safe_execute(exp)
|
|
|
|
if exp_stats == "failed":
|
|
|
|
failed_exps.append(exp)
|
|
|
|
else:
|
|
|
|
successful_exp_stats.append(exp_stats)
|
|
|
|
# noinspection PyTypeChecker
|
|
|
|
return self._return_from_successful_and_failed_exps(successful_exp_stats, failed_exps)
|
2024-04-21 01:25:33 +02:00
|
|
|
|
|
|
|
|
|
|
|
class JoblibExpLauncher(ExpLauncher):
|
2024-05-04 22:09:32 +02:00
|
|
|
def __init__(
|
|
|
|
self,
|
|
|
|
joblib_cfg: JoblibConfig | None = None,
|
|
|
|
experiment_runner: Callable[
|
|
|
|
[Experiment],
|
|
|
|
InfoStats | None,
|
|
|
|
] = lambda exp: exp.run().trainer_result,
|
|
|
|
) -> None:
|
|
|
|
super().__init__(experiment_runner=experiment_runner)
|
2024-04-21 01:25:33 +02:00
|
|
|
self.joblib_cfg = copy(joblib_cfg) if joblib_cfg is not None else JoblibConfig()
|
|
|
|
# Joblib's backend is hard-coded to loky since the threading backend produces different results
|
|
|
|
if self.joblib_cfg.backend != "loky":
|
|
|
|
log.warning(
|
|
|
|
f"Ignoring the user provided joblib backend {self.joblib_cfg.backend} and using loky instead. "
|
|
|
|
f"The current implementation requires loky to work and will be relaxed soon",
|
|
|
|
)
|
|
|
|
self.joblib_cfg.backend = "loky"
|
|
|
|
|
2024-05-04 22:09:32 +02:00
|
|
|
def _launch(self, experiments: Sequence[Experiment]) -> list[InfoStats | None]:
|
|
|
|
results = Parallel(**asdict(self.joblib_cfg))(
|
|
|
|
delayed(self._safe_execute)(exp) for exp in experiments
|
|
|
|
)
|
|
|
|
successful_exps = []
|
|
|
|
failed_exps = []
|
|
|
|
for exp, result in zip(experiments, results, strict=True):
|
|
|
|
if result == "failed":
|
|
|
|
failed_exps.append(exp)
|
|
|
|
else:
|
|
|
|
successful_exps.append(result)
|
|
|
|
return self._return_from_successful_and_failed_exps(successful_exps, failed_exps)
|
2024-04-21 01:25:33 +02:00
|
|
|
|
|
|
|
|
|
|
|
class RegisteredExpLauncher(Enum):
|
|
|
|
joblib = "joblib"
|
|
|
|
sequential = "sequential"
|
|
|
|
|
|
|
|
def create_launcher(self) -> ExpLauncher:
|
|
|
|
match self:
|
|
|
|
case RegisteredExpLauncher.joblib:
|
|
|
|
return JoblibExpLauncher()
|
|
|
|
case RegisteredExpLauncher.sequential:
|
|
|
|
return SequentialExpLauncher()
|
|
|
|
case _:
|
|
|
|
raise NotImplementedError(
|
|
|
|
f"Launcher {self} is not yet implemented.",
|
|
|
|
)
|