# Changes

## Dependencies

- New extra "eval"

## API Extensions

- `Experiment` and `ExperimentConfig` now have a `name`, which can nevertheless be overridden when `Experiment.run()` is called.
- When building an `Experiment` from an `ExperimentConfig`, the user has the option to add seed information to the name.
- New method `build_default_seeded_experiments` in `ExperimentConfig`.
- `SamplingConfig` has an explicit training seed; `test_seed` is inferred from it.
- New `evaluation` package for repeating the same experiment with multiple seeds and aggregating the results (an important extension!). Currently in alpha state; a usage sketch follows after this changelog.
- Loggers can now restore logged data into Python via the new `restore_logged_data` method.

## Breaking Changes

- `AtariEnvFactory` (in examples) now receives explicit train and test seeds.
- `EnvFactoryRegistered` now requires an explicit `test_seed`.
- `BaseLogger.prepare_dict_for_logging` is now abstract.

---------

Co-authored-by: Maximilian Huettenrauch <m.huettenrauch@appliedai.de>
Co-authored-by: Michael Panchenko <m.panchenko@appliedai.de>
Co-authored-by: Michael Panchenko <35432522+MischaPanch@users.noreply.github.com>
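The new pieces combine roughly as follows. This is a hedged sketch, not documented usage: `build_default_seeded_experiments` is only named above, so its signature (a single `num_experiments` argument returning a sequence of `Experiment` instances) is an assumption, as is the import path `tianshou.evaluation.launcher`; the launcher calls themselves match the launcher module shown further below.

```python
# Hedged sketch of the multi-seed workflow; lines marked "assumed" are not confirmed by the changelog.
from tianshou.evaluation.launcher import RegisteredExpLauncher  # assumed import path
from tianshou.highlevel.experiment import ExperimentConfig

config = ExperimentConfig()  # configure seed, persistence, etc. as usual

# Assumed signature: one Experiment per seed, with seed info optionally appended to each name.
experiments = config.build_default_seeded_experiments(num_experiments=5)

# Run the seeded experiments one after another (SequentialExpLauncher, see the module below).
launcher = RegisteredExpLauncher.sequential.create_launcher()
launcher.launch(experiments)
```

Switching to parallel execution only requires selecting `RegisteredExpLauncher.joblib` or constructing a `JoblibExpLauncher` directly, as shown after the module.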
The launcher module (Python, 76 lines, 2.5 KiB):
"""Provides a basic interface for launching experiments. The API is experimental and subject to change!."""
|
|
|
|
import logging
|
|
from abc import ABC, abstractmethod
|
|
from collections.abc import Sequence
|
|
from copy import copy
|
|
from dataclasses import asdict, dataclass
|
|
from enum import Enum
|
|
from typing import Literal
|
|
|
|
from joblib import Parallel, delayed
|
|
|
|
from tianshou.highlevel.experiment import Experiment
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
|
|
class JoblibConfig:
|
|
n_jobs: int = -1
|
|
"""The maximum number of concurrently running jobs. If -1, all CPUs are used."""
|
|
backend: Literal["loky", "multiprocessing", "threading"] | None = "loky"
|
|
"""Allows to hard-code backend, otherwise inferred based on prefer and require."""
|
|
verbose: int = 10
|
|
"""If greater than zero, prints progress messages."""
|
|
|
|
|
|
class ExpLauncher(ABC):
|
|
@abstractmethod
|
|
def launch(self, experiments: Sequence[Experiment]) -> None:
|
|
pass
|
|
|
|
|
|
class SequentialExpLauncher(ExpLauncher):
|
|
def launch(self, experiments: Sequence[Experiment]) -> None:
|
|
for exp in experiments:
|
|
exp.run()
|
|
|
|
|
|
class JoblibExpLauncher(ExpLauncher):
|
|
def __init__(self, joblib_cfg: JoblibConfig | None = None) -> None:
|
|
self.joblib_cfg = copy(joblib_cfg) if joblib_cfg is not None else JoblibConfig()
|
|
# Joblib's backend is hard-coded to loky since the threading backend produces different results
|
|
if self.joblib_cfg.backend != "loky":
|
|
log.warning(
|
|
f"Ignoring the user provided joblib backend {self.joblib_cfg.backend} and using loky instead. "
|
|
f"The current implementation requires loky to work and will be relaxed soon",
|
|
)
|
|
self.joblib_cfg.backend = "loky"
|
|
|
|
def launch(self, experiments: Sequence[Experiment]) -> None:
|
|
Parallel(**asdict(self.joblib_cfg))(delayed(self._safe_execute)(exp) for exp in experiments)
|
|
|
|
@staticmethod
|
|
def _safe_execute(exp: Experiment) -> None:
|
|
try:
|
|
exp.run()
|
|
except BaseException as e:
|
|
log.error(e)
|
|
|
|
|
|
class RegisteredExpLauncher(Enum):
|
|
joblib = "joblib"
|
|
sequential = "sequential"
|
|
|
|
def create_launcher(self) -> ExpLauncher:
|
|
match self:
|
|
case RegisteredExpLauncher.joblib:
|
|
return JoblibExpLauncher()
|
|
case RegisteredExpLauncher.sequential:
|
|
return SequentialExpLauncher()
|
|
case _:
|
|
raise NotImplementedError(
|
|
f"Launcher {self} is not yet implemented.",
|
|
)
|
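A usage sketch for the launchers above. Only `JoblibConfig`, `JoblibExpLauncher`, `RegisteredExpLauncher`, and their methods come from the module itself; the import path is assumed and `make_experiments` is a hypothetical placeholder for however the experiments are built.

```python
from collections.abc import Sequence

from tianshou.evaluation.launcher import (  # assumed import path for the module above
    JoblibConfig,
    JoblibExpLauncher,
    RegisteredExpLauncher,
)
from tianshou.highlevel.experiment import Experiment


def make_experiments() -> Sequence[Experiment]:
    """Hypothetical placeholder: build the (e.g. seeded) experiments with the high-level API."""
    raise NotImplementedError


experiments = make_experiments()

# Parallel execution via joblib, with at most 4 worker processes and no progress output.
JoblibExpLauncher(JoblibConfig(n_jobs=4, verbose=0)).launch(experiments)

# Or select a launcher by name via the registry enum (here: plain sequential execution).
RegisteredExpLauncher.sequential.create_launcher().launch(experiments)
```

Note that `JoblibExpLauncher` currently forces the `loky` backend (logging a warning if a different one is configured) and that exceptions raised by individual experiments are logged rather than propagated, so a failing run does not abort the whole batch.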