From bf3859a457557f2cd2822af5d598c9d0ba6027ee Mon Sep 17 00:00:00 2001
From: Michael Panchenko
Date: Sat, 4 May 2024 22:09:32 +0200
Subject: [PATCH] Extension of ExpLauncher and DataclassPPrintMixin

1. Launch in main process if only 1 exp is passed
2. Launcher returns a list of stats for successful exps
3. More detailed logging for unsuccessful exps
4. Raise error if all runs were unsuccessful
5. DataclassPPrintMixin allows retrieving a pretty repr string
6. Minor improvements in docstrings
---
 CHANGELOG.md                           |  10 ++-
 examples/mujoco/mujoco_ppo_hl_multi.py |  40 ++++++----
 tianshou/evaluation/launcher.py        | 104 +++++++++++++++++++++----
 tianshou/highlevel/config.py           |   2 +-
 tianshou/highlevel/experiment.py       |  10 ++-
 tianshou/utils/print.py                |  19 +++--
 6 files changed, 142 insertions(+), 43 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 807a9da..886c68f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -3,6 +3,8 @@
 ## Release 1.1.0
 
 ### Api Extensions
+- `evaluation`: New package for repeating the same experiment with multiple seeds and aggregating the results. #1074 #1141
+  - The module `evaluation.launchers` for parallelization is currently in alpha state.
 - `data`:
   - `Batch`:
     - Add methods `to_dict` and `to_list_of_dicts`. #1063 #1098
@@ -24,7 +26,7 @@
   - `SamplingConfig`:
     - Add support for `batch_size=None`. #1077
     - Add `training_seed` for explicit seeding of training and test environments, the `test_seed` is inferred from `training_seed`. #1074
-  - `highlevel.experiment`:
+  - `experiment`:
     - `Experiment` now has a `name` attribute, which can be set using `ExperimentBuilder.with_name` and
       which determines the default run name and therefore the persistence subdirectory.
       It can still be overridden in `Experiment.run()`, the new parameter name being `run_name` rather than
@@ -34,8 +36,8 @@
       - Add method `build_seeded_collection` for the sound creation of multiple
         experiments with varying random seeds #1131
       - Add method `copy` to facilitate the creation of multiple experiments from a single builder #1131
-- `evaluation`: New package for repeating the same experiment with multiple seeds and aggregating the results. #1074
-  - The module `evaluation.launchers` for parallelization is currently in alpha state.
+  - `env`:
+    - Added new `VectorEnvType` called `SUBPROC_SHARED_MEM_AUTO`, used for Atari and Mujoco venv creation. #1141
 - Loggers can now restore the logged data into python by using the new `restore_logged_data` method. #1074
 - `utils`:
   - `net.continuous.Critic`:
@@ -45,6 +47,8 @@
     to reuse the actor's preprocessing network #1128
   - `torch_utils` (new module)
     - Added context managers `torch_train_mode` and `policy_within_training_step` #1123
+  - `print`
+    - `DataclassPPrintMixin` now supports outputting a string, not just printing the pretty repr. #1141
 
 ### Fixes
 - `CriticFactoryReuseActor`: Enable the Critic flag `apply_preprocess_net_to_obs_only` for continuous critics,
diff --git a/examples/mujoco/mujoco_ppo_hl_multi.py b/examples/mujoco/mujoco_ppo_hl_multi.py
index 6f67de3..33ec34a 100644
--- a/examples/mujoco/mujoco_ppo_hl_multi.py
+++ b/examples/mujoco/mujoco_ppo_hl_multi.py
@@ -6,16 +6,13 @@ For each repetition, a policy seed, train env seeds, and test env seeds are set
 are non-intersecting with the seeds of the other experiments.
 Each experiment's results are stored in a separate subdirectory.
 
-The final results are aggregated and turned into useful statistics with the rliable API.
+The final results are aggregated and turned into useful statistics with the rliable package.
 The call to `eval_experiments` will load the results from the log directory and create an interp-quantile
 mean plot for the returns as well as a performance profile plot.
 These plots are saved in the log directory and displayed in the console.
 """
 
 import os
-import sys
-from collections.abc import Sequence
-from typing import Literal
 
 import torch
 
@@ -40,14 +37,14 @@ log = logging.getLogger(__name__)
 
 
 def main(
-    num_experiments: int = 2,
-    run_experiments_sequentially: bool = True,
+    num_experiments: int = 5,
+    run_experiments_sequentially: bool = False,
 ) -> RLiableExperimentResult:
-    """:param run_experiments_sequentially: if True, the experiments are run sequentially, otherwise in parallel.
-    LIMITATIONS: currently, the parallel execution does not seem to work properly on linux.
-    It might generally be undesired to run multiple experiments in parallel on the same machine,
-    as a single experiment already uses all available CPU cores by default.
-    :return: the directory where the results are stored
+    """:param num_experiments: the number of experiments to run. The experiments differ exclusively in the seeds.
+    :param run_experiments_sequentially: if True, the experiments are run sequentially, otherwise in parallel.
+        If a single experiment is set to use all available CPU cores,
+        it might be undesired to run multiple experiments in parallel on the same machine.
+    :return: an object containing rliable-based evaluation results
     """
     task = "Ant-v4"
     persistence_dir = os.path.abspath(os.path.join("log", task, "ppo", datetime_tag()))
 
@@ -58,12 +55,12 @@ def main(
         num_epochs=1,
         step_per_epoch=5000,
         batch_size=64,
-        num_train_envs=10,
-        num_test_envs=10,
-        num_test_episodes=10,
+        num_train_envs=5,
+        num_test_envs=5,
+        num_test_episodes=5,
         buffer_size=4096,
         step_per_collect=2048,
-        repeat_per_collect=10,
+        repeat_per_collect=1,
     )
 
     env_factory = MujocoEnvFactory(
@@ -106,7 +103,18 @@ def main(
         launcher = RegisteredExpLauncher.sequential.create_launcher()
     else:
         launcher = RegisteredExpLauncher.joblib.create_launcher()
-    experiment_collection.run(launcher)
+    successful_experiment_stats = experiment_collection.run(launcher)
+    log.info(f"Successfully completed {len(successful_experiment_stats)} experiments.")
+
+    num_successful_experiments = len(successful_experiment_stats)
+    for i, info_stats in enumerate(successful_experiment_stats, start=1):
+        if info_stats is not None:
+            log.info(f"Training stats for successful experiment {i}/{num_successful_experiments}:")
+            log.info(info_stats.pprints_asdict())
+        else:
+            log.info(
+                f"No training stats available for successful experiment {i}/{num_successful_experiments}.",
+            )
 
     rliable_result = RLiableExperimentResult.load_from_disk(persistence_dir)
     rliable_result.eval_results(show_plots=True, save_plots=True)
diff --git a/tianshou/evaluation/launcher.py b/tianshou/evaluation/launcher.py
index db97519..534e5f8 100644
--- a/tianshou/evaluation/launcher.py
+++ b/tianshou/evaluation/launcher.py
@@ -2,7 +2,7 @@
 import logging
 from abc import ABC, abstractmethod
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from copy import copy
 from dataclasses import asdict, dataclass
 from enum import Enum
@@ -10,6 +10,7 @@ from typing import Literal
 
 from joblib import Parallel, delayed
 
+from tianshou.data import InfoStats
 from tianshou.highlevel.experiment import Experiment
 
 log = logging.getLogger(__name__)
@@ -26,19 +27,89 @@ class JoblibConfig:
 
 
 class ExpLauncher(ABC):
+    def __init__(
+        self,
+        experiment_runner: Callable[
+            [Experiment],
+            InfoStats | None,
+        ] = lambda exp: exp.run().trainer_result,
+    ):
+        """:param experiment_runner: can be used to override the default way in which an experiment is executed.
+        Can be useful, e.g., if one wants to use the high-level interfaces to set up an experiment (or an experiment
+        collection) and tinker with it prior to execution. This need often arises when prototyping with mechanisms
+        that are not yet supported by the high-level interfaces.
+        Passing this allows arbitrary things to happen during experiment execution, so use it with caution!
+        """
+        self.experiment_runner = experiment_runner
+
     @abstractmethod
-    def launch(self, experiments: Sequence[Experiment]) -> None:
-        pass
+    def _launch(self, experiments: Sequence[Experiment]) -> list[InfoStats | None]:
+        """Should call `self.experiment_runner` for each experiment in experiments and aggregate the results."""
+
+    def _safe_execute(self, exp: Experiment) -> InfoStats | None | Literal["failed"]:
+        try:
+            return self.experiment_runner(exp)
+        except BaseException as e:
+            log.error(f"Failed to run experiment {exp}.", exc_info=e)
+            return "failed"
+
+    @staticmethod
+    def _return_from_successful_and_failed_exps(
+        successful_exp_stats: list[InfoStats | None],
+        failed_exps: list[Experiment],
+    ) -> list[InfoStats | None]:
+        if not successful_exp_stats:
+            raise RuntimeError("All experiments failed, see error logs for more details.")
+        if failed_exps:
+            log.error(
+                f"Failed to run the following "
+                f"{len(failed_exps)}/{len(successful_exp_stats) + len(failed_exps)} experiments: {failed_exps}. "
+                f"See the logs for more details. "
+                f"Returning the results of {len(successful_exp_stats)} successful experiments.",
+            )
+        return successful_exp_stats
+
+    def launch(self, experiments: Sequence[Experiment]) -> list[InfoStats | None]:
+        """Will return the results of successfully executed experiments.
+
+        If a single experiment is passed, will not use parallelism and run it in the main process.
+        Failed experiments will be logged, and a RuntimeError is only raised if all experiments have failed.
+        """
+        if len(experiments) == 1:
+            log.info(
+                "A single experiment is being run, will not use parallelism and run it in the main process.",
+            )
+            return [self.experiment_runner(experiments[0])]
+        return self._launch(experiments)
 
 
 class SequentialExpLauncher(ExpLauncher):
-    def launch(self, experiments: Sequence[Experiment]) -> None:
+    """Convenience wrapper around a simple for loop to run experiments sequentially."""
+
+    def _launch(self, experiments: Sequence[Experiment]) -> list[InfoStats | None]:
+        successful_exp_stats = []
+        failed_exps = []
         for exp in experiments:
-            exp.run()
+            exp_stats = self._safe_execute(exp)
+            if exp_stats == "failed":
+                failed_exps.append(exp)
+            else:
+                successful_exp_stats.append(exp_stats)
+        # noinspection PyTypeChecker
+        return self._return_from_successful_and_failed_exps(successful_exp_stats, failed_exps)
 
 
 class JoblibExpLauncher(ExpLauncher):
-    def __init__(self, joblib_cfg: JoblibConfig | None = None) -> None:
+    def __init__(
+        self,
+        joblib_cfg: JoblibConfig | None = None,
+        experiment_runner: Callable[
+            [Experiment],
+            InfoStats | None,
+        ] = lambda exp: exp.run().trainer_result,
+    ) -> None:
+        super().__init__(experiment_runner=experiment_runner)
         self.joblib_cfg = copy(joblib_cfg) if joblib_cfg is not None else JoblibConfig()
         # Joblib's backend is hard-coded to loky since the threading backend produces different results
         if self.joblib_cfg.backend != "loky":
@@ -48,15 +119,18 @@ class JoblibExpLauncher(ExpLauncher):
             )
             self.joblib_cfg.backend = "loky"
 
-    def launch(self, experiments: Sequence[Experiment]) -> None:
-        Parallel(**asdict(self.joblib_cfg))(delayed(self._safe_execute)(exp) for exp in experiments)
-
-    @staticmethod
-    def _safe_execute(exp: Experiment) -> None:
-        try:
-            exp.run()
-        except BaseException as e:
-            log.error(e)
+    def _launch(self, experiments: Sequence[Experiment]) -> list[InfoStats | None]:
+        results = Parallel(**asdict(self.joblib_cfg))(
+            delayed(self._safe_execute)(exp) for exp in experiments
+        )
+        successful_exps = []
+        failed_exps = []
+        for exp, result in zip(experiments, results, strict=True):
+            if result == "failed":
+                failed_exps.append(exp)
+            else:
+                successful_exps.append(result)
+        return self._return_from_successful_and_failed_exps(successful_exps, failed_exps)
 
 
 class RegisteredExpLauncher(Enum):
diff --git a/tianshou/highlevel/config.py b/tianshou/highlevel/config.py
index f8ca0c0..951f2f3 100644
--- a/tianshou/highlevel/config.py
+++ b/tianshou/highlevel/config.py
@@ -85,7 +85,7 @@ class SamplingConfig(ToStringMixin):
     """
     controls, within one gradient update step of an on-policy algorithm, the number of times
     an actual gradient update is applied using the full collected dataset, i.e. if the parameter is
-    `n`, then the collected data shall be used five times to update the policy within the same
+    5, then the collected data shall be used five times to update the policy within the same
     training step.
 
     The parameter is ignored and may be set to None for off-policy and offline algorithms.
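Usage sketch for the `experiment_runner` hook introduced in the launcher changes above (not part of the patch): the classes, the default runner `lambda exp: exp.run().trainer_result`, and the `run_name` parameter of `Experiment.run()` come from this patch and its changelog entry, while the runner body, the chosen run name, and the `experiments` variable are illustrative assumptions.

# Illustrative sketch, not part of the patch: supplying a custom experiment_runner.
# A custom callable may tinker with the experiment before execution and must
# return InfoStats | None, matching the default runner's behaviour.
from tianshou.data import InfoStats
from tianshou.evaluation.launcher import JoblibExpLauncher
from tianshou.highlevel.experiment import Experiment


def my_experiment_runner(exp: Experiment) -> InfoStats | None:
    # e.g. override the run name (the changelog notes that Experiment.run() accepts run_name);
    # "custom_run" is just a placeholder value.
    result = exp.run(run_name="custom_run")
    return result.trainer_result


launcher = JoblibExpLauncher(experiment_runner=my_experiment_runner)
# stats = launcher.launch(experiments)  # list[InfoStats | None] of the successful runs;
#                                       # `experiments` is assumed to be a sequence of Experiment objects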
diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py
index 99aadc2..dbcd3f1 100644
--- a/tianshou/highlevel/experiment.py
+++ b/tianshou/highlevel/experiment.py
@@ -87,7 +87,6 @@ from tianshou.utils.string import ToStringMixin
 
 if TYPE_CHECKING:
     from tianshou.evaluation.launcher import ExpLauncher, RegisteredExpLauncher
-
 log = logging.getLogger(__name__)
 
 
@@ -356,15 +355,20 @@ class Experiment(ToStringMixin):
 
 
 class ExperimentCollection:
+    """Shallow wrapper around a list of experiments providing a simple interface for running them with a launcher."""
+
     def __init__(self, experiments: list[Experiment]):
         self.experiments = experiments
 
-    def run(self, launcher: Union["ExpLauncher", "RegisteredExpLauncher"]) -> None:
+    def run(
+        self,
+        launcher: Union["ExpLauncher", "RegisteredExpLauncher"],
+    ) -> list[InfoStats | None]:
         from tianshou.evaluation.launcher import RegisteredExpLauncher
 
         if isinstance(launcher, RegisteredExpLauncher):
             launcher = launcher.create_launcher()
-        launcher.launch(experiments=self.experiments)
+        return launcher.launch(experiments=self.experiments)
 
 
 class ExperimentBuilder:
diff --git a/tianshou/utils/print.py b/tianshou/utils/print.py
index 82afc09..88035ba 100644
--- a/tianshou/utils/print.py
+++ b/tianshou/utils/print.py
@@ -1,20 +1,29 @@
+import pprint
 from collections.abc import Sequence
 from dataclasses import asdict, dataclass
-from pprint import pprint
 
 
 @dataclass
 class DataclassPPrintMixin:
-    def pprint_asdict(self, exclude_fields: Sequence[str] | None = None) -> None:
+    def pprint_asdict(self, exclude_fields: Sequence[str] | None = None, indent: int = 4) -> None:
         """Pretty-print the object as a dict, excluding specified fields.
 
         :param exclude_fields: A sequence of field names to exclude from the output.
             If None, no fields are excluded.
+        :param indent: The indentation to use when pretty-printing.
         """
-        print(f"{self.__class__.__name__}")
-        print("----------------------------------------")
+        print(self.pprints_asdict(exclude_fields=exclude_fields, indent=indent))
+
+    def pprints_asdict(self, exclude_fields: Sequence[str] | None = None, indent: int = 4) -> str:
+        """Return a pretty-printed string of the object as a dict, excluding specified fields.
+
+        :param exclude_fields: A sequence of field names to exclude from the output.
+            If None, no fields are excluded.
+        :param indent: The indentation to use when pretty-printing.
+        """
+        prefix = f"{self.__class__.__name__}\n----------------------------------------\n"
         print_dict = asdict(self)
         exclude_fields = exclude_fields or []
         for field in exclude_fields:
            print_dict.pop(field, None)
-        pprint(print_dict)
+        return prefix + pprint.pformat(print_dict, indent=indent)
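For reference, a condensed sketch of how the extended pieces fit together (not part of the patch): `build_seeded_collection` comes from the changelog entry above and `RegisteredExpLauncher`, `ExperimentCollection.run`, and `pprints_asdict` come from this patch, while the `experiment_builder` variable and the exact argument form of `build_seeded_collection` are assumptions.

# Condensed usage sketch, not part of the patch. Assumes `experiment_builder` is a fully
# configured ExperimentBuilder, e.g. the one set up in examples/mujoco/mujoco_ppo_hl_multi.py.
from tianshou.evaluation.launcher import RegisteredExpLauncher

# Build experiments that differ only in their seeds (method listed in the changelog;
# the argument form is assumed here).
experiment_collection = experiment_builder.build_seeded_collection(5)

launcher = RegisteredExpLauncher.joblib.create_launcher()

# run() now returns the stats of the successful experiments; failures are logged,
# and a RuntimeError is raised only if every experiment failed.
successful_experiment_stats = experiment_collection.run(launcher)

for info_stats in successful_experiment_stats:
    if info_stats is not None:
        # New in this patch: retrieve the pretty repr as a string instead of printing directly.
        print(info_stats.pprints_asdict())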