Extension of ExpLauncher and DataclassPPrintMixin
1. Launch in main process if only 1 exp is passed
2. Launcher returns a list of stats for successful exps
3. More detailed logging for unsuccessful exps
4. Raise error if all runs were unsuccessful
5. DataclassPPrintMixin allows retrieving a pretty repr string
6. Minor improvements in docstrings
parent 1cd22f1d32, commit bf3859a457
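In short, the user-facing flow after this commit looks roughly like the sketch below. This is a minimal sketch: `builder` is assumed to be a fully configured `ExperimentBuilder` (factory setup omitted), and `build_seeded_collection` is the builder method referenced in the changelog.

```python
from tianshou.evaluation.launcher import RegisteredExpLauncher

# `builder` is assumed to be a fully configured ExperimentBuilder.
collection = builder.build_seeded_collection(num_experiments=5)

launcher = RegisteredExpLauncher.sequential.create_launcher()
# run() now returns list[InfoStats | None] for the successful experiments;
# a RuntimeError is raised only if all experiments failed.
stats = collection.run(launcher)
for info_stats in stats:
    if info_stats is not None:
        print(info_stats.pprints_asdict())  # new: pretty repr as a string
```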
CHANGELOG.md (10 lines changed)
```diff
@@ -3,6 +3,8 @@
 ## Release 1.1.0

 ### Api Extensions
+- `evaluation`: New package for repeating the same experiment with multiple seeds and aggregating the results. #1074 #1141
+  - The module `evaluation.launchers` for parallelization is currently in alpha state.
 - `data`:
   - `Batch`:
     - Add methods `to_dict` and `to_list_of_dicts`. #1063 #1098
@@ -24,7 +26,7 @@
 - `SamplingConfig`:
   - Add support for `batch_size=None`. #1077
   - Add `training_seed` for explicit seeding of training and test environments, the `test_seed` is inferred from `training_seed`. #1074
-- `highlevel.experiment`:
+- `experiment`:
   - `Experiment` now has a `name` attribute, which can be set using `ExperimentBuilder.with_name` and
     which determines the default run name and therefore the persistence subdirectory.
     It can still be overridden in `Experiment.run()`, the new parameter name being `run_name` rather than
@@ -34,8 +36,8 @@
   - Add method `build_seeded_collection` for the sound creation of multiple
     experiments with varying random seeds #1131
   - Add method `copy` to facilitate the creation of multiple experiments from a single builder #1131
-- `evaluation`: New package for repeating the same experiment with multiple seeds and aggregating the results. #1074
-  - The module `evaluation.launchers` for parallelization is currently in alpha state.
+- `env`:
+  - Added new `VectorEnvType` called `SUBPROC_SHARED_MEM_AUTO`, used for Atari and Mujoco venv creation. #1141
 - Loggers can now restore the logged data into python by using the new `restore_logged_data` method. #1074
 - `utils`:
   - `net.continuous.Critic`:
@@ -45,6 +47,8 @@
       to reuse the actor's preprocessing network #1128
   - `torch_utils` (new module)
     - Added context managers `torch_train_mode` and `policy_within_training_step` #1123
+  - `print`
+    - `DataclassPPrintMixin` now supports outputting a string, not just printing the pretty repr. #1141

 ### Fixes
 - `CriticFactoryReuseActor`: Enable the Critic flag `apply_preprocess_net_to_obs_only` for continuous critics,
```
examples/mujoco/mujoco_ppo_hl_multi.py

```diff
@@ -6,16 +6,13 @@ For each repetition, a policy seed, train env seeds, and test env seeds are set
 are non-intersecting with the seeds of the other experiments.
 Each experiment's results are stored in a separate subdirectory.

-The final results are aggregated and turned into useful statistics with the rliable API.
+The final results are aggregated and turned into useful statistics with the rliable package.
 The call to `eval_experiments` will load the results from the log directory and
 create an interp-quantile mean plot for the returns as well as a performance profile plot.
 These plots are saved in the log directory and displayed in the console.
 """

 import os
-import sys
-from collections.abc import Sequence
-from typing import Literal

 import torch

@@ -40,14 +37,14 @@ log = logging.getLogger(__name__)


 def main(
-    num_experiments: int = 2,
-    run_experiments_sequentially: bool = True,
+    num_experiments: int = 5,
+    run_experiments_sequentially: bool = False,
 ) -> RLiableExperimentResult:
-    """:param run_experiments_sequentially: if True, the experiments are run sequentially, otherwise in parallel.
-    LIMITATIONS: currently, the parallel execution does not seem to work properly on linux.
-    It might generally be undesired to run multiple experiments in parallel on the same machine,
-    as a single experiment already uses all available CPU cores by default.
-    :return: the directory where the results are stored
+    """:param num_experiments: the number of experiments to run. The experiments differ exclusively in the seeds.
+    :param run_experiments_sequentially: if True, the experiments are run sequentially, otherwise in parallel.
+        If a single experiment is set to use all available CPU cores,
+        it might be undesired to run multiple experiments in parallel on the same machine,
+    :return: an object containing rliable-based evaluation results
     """
     task = "Ant-v4"
     persistence_dir = os.path.abspath(os.path.join("log", task, "ppo", datetime_tag()))
@@ -58,12 +55,12 @@ def main(
         num_epochs=1,
         step_per_epoch=5000,
         batch_size=64,
-        num_train_envs=10,
-        num_test_envs=10,
-        num_test_episodes=10,
+        num_train_envs=5,
+        num_test_envs=5,
+        num_test_episodes=5,
         buffer_size=4096,
         step_per_collect=2048,
-        repeat_per_collect=10,
+        repeat_per_collect=1,
     )

     env_factory = MujocoEnvFactory(
@@ -106,7 +103,18 @@ def main(
         launcher = RegisteredExpLauncher.sequential.create_launcher()
     else:
         launcher = RegisteredExpLauncher.joblib.create_launcher()
-    experiment_collection.run(launcher)
+    successful_experiment_stats = experiment_collection.run(launcher)
+    log.info(f"Successfully completed {len(successful_experiment_stats)} experiments.")
+
+    num_successful_experiments = len(successful_experiment_stats)
+    for i, info_stats in enumerate(successful_experiment_stats, start=1):
+        if info_stats is not None:
+            log.info(f"Training stats for successful experiment {i}/{num_successful_experiments}:")
+            log.info(info_stats.pprints_asdict())
+        else:
+            log.info(
+                f"No training stats available for successful experiment {i}/{num_successful_experiments}.",
+            )

     rliable_result = RLiableExperimentResult.load_from_disk(persistence_dir)
     rliable_result.eval_results(show_plots=True, save_plots=True)
```
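Since `ExperimentCollection.run` now returns the per-experiment stats (see the launcher changes below), the script can log each experiment's training stats as shown above. A minimal sanity-check sketch of invoking the example, using only the parameters defined in `main()` above:

```python
# Run a small sequential batch; main() returns the RLiableExperimentResult
# that it loads from the persistence directory.
result = main(num_experiments=2, run_experiments_sequentially=True)
```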
tianshou/evaluation/launcher.py

```diff
@@ -2,7 +2,7 @@

 import logging
 from abc import ABC, abstractmethod
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from copy import copy
 from dataclasses import asdict, dataclass
 from enum import Enum
@@ -10,6 +10,7 @@ from typing import Literal

 from joblib import Parallel, delayed

+from tianshou.data import InfoStats
 from tianshou.highlevel.experiment import Experiment

 log = logging.getLogger(__name__)
@@ -26,19 +27,89 @@ class JoblibConfig:


 class ExpLauncher(ABC):
+    def __init__(
+        self,
+        experiment_runner: Callable[
+            [Experiment],
+            InfoStats | None,
+        ] = lambda exp: exp.run().trainer_result,
+    ):
+        """:param experiment_runner: can be used to override the default way in which an experiment is executed.
+        Can be useful e.g., if one wants to use the high-level interfaces to setup an experiment (or an experiment
+        collection) and tinker with it prior to execution. This need often arises when prototyping with mechanisms
+        that are not yet supported by the high-level interfaces.
+        Passing this allows arbitrary things to happen during experiment execution, so use it with caution!
+        """
+        self.experiment_runner = experiment_runner
+
     @abstractmethod
-    def launch(self, experiments: Sequence[Experiment]) -> None:
-        pass
+    def _launch(self, experiments: Sequence[Experiment]) -> list[InfoStats | None]:
+        """Should call `self.experiment_runner` for each experiment in experiments and aggregate the results."""
+
+    def _safe_execute(self, exp: Experiment) -> InfoStats | None | Literal["failed"]:
+        try:
+            return self.experiment_runner(exp)
+        except BaseException as e:
+            log.error(f"Failed to run experiment {exp}.", exc_info=e)
+            return "failed"
+
+    @staticmethod
+    def _return_from_successful_and_failed_exps(
+        successful_exp_stats: list[InfoStats | None],
+        failed_exps: list[Experiment],
+    ) -> list[InfoStats | None]:
+        if not successful_exp_stats:
+            raise RuntimeError("All experiments failed, see error logs for more details.")
+        if failed_exps:
+            log.error(
+                f"Failed to run the following "
+                f"{len(failed_exps)}/{len(successful_exp_stats) + len(failed_exps)} experiments: {failed_exps}. "
+                f"See the logs for more details. "
+                f"Returning the results of {len(successful_exp_stats)} successful experiments.",
+            )
+        return successful_exp_stats
+
+    def launch(self, experiments: Sequence[Experiment]) -> list[InfoStats | None]:
+        """Will return the results of successfully executed experiments.
+
+        If a single experiment is passed, will not use parallelism and run it in the main process.
+        Failed experiments will be logged, and a RuntimeError is only raised if all experiments have failed.
+        """
+        if len(experiments) == 1:
+            log.info(
+                "A single experiment is being run, will not use parallelism and run it in the main process.",
+            )
+            return [self.experiment_runner(experiments[0])]
+        return self._launch(experiments)


 class SequentialExpLauncher(ExpLauncher):
-    def launch(self, experiments: Sequence[Experiment]) -> None:
+    """Convenience wrapper around a simple for loop to run experiments sequentially."""
+
+    def _launch(self, experiments: Sequence[Experiment]) -> list[InfoStats | None]:
+        successful_exp_stats = []
+        failed_exps = []
         for exp in experiments:
-            exp.run()
+            exp_stats = self._safe_execute(exp)
+            if exp_stats == "failed":
+                failed_exps.append(exp)
+            else:
+                successful_exp_stats.append(exp_stats)
+        # noinspection PyTypeChecker
+        return self._return_from_successful_and_failed_exps(successful_exp_stats, failed_exps)


 class JoblibExpLauncher(ExpLauncher):
-    def __init__(self, joblib_cfg: JoblibConfig | None = None) -> None:
+    def __init__(
+        self,
+        joblib_cfg: JoblibConfig | None = None,
+        experiment_runner: Callable[
+            [Experiment],
+            InfoStats | None,
+        ] = lambda exp: exp.run().trainer_result,
+    ) -> None:
+        super().__init__(experiment_runner=experiment_runner)
         self.joblib_cfg = copy(joblib_cfg) if joblib_cfg is not None else JoblibConfig()
         # Joblib's backend is hard-coded to loky since the threading backend produces different results
         if self.joblib_cfg.backend != "loky":
@@ -48,15 +119,18 @@ class JoblibExpLauncher(ExpLauncher):
             )
             self.joblib_cfg.backend = "loky"

-    def launch(self, experiments: Sequence[Experiment]) -> None:
-        Parallel(**asdict(self.joblib_cfg))(delayed(self._safe_execute)(exp) for exp in experiments)
-
-    @staticmethod
-    def _safe_execute(exp: Experiment) -> None:
-        try:
-            exp.run()
-        except BaseException as e:
-            log.error(e)
+    def _launch(self, experiments: Sequence[Experiment]) -> list[InfoStats | None]:
+        results = Parallel(**asdict(self.joblib_cfg))(
+            delayed(self._safe_execute)(exp) for exp in experiments
+        )
+        successful_exps = []
+        failed_exps = []
+        for exp, result in zip(experiments, results, strict=True):
+            if result == "failed":
+                failed_exps.append(exp)
+            else:
+                successful_exps.append(result)
+        return self._return_from_successful_and_failed_exps(successful_exps, failed_exps)


 class RegisteredExpLauncher(Enum):
```
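Two usage sketches for the reworked launchers. The `experiment_runner` hook is taken from the `__init__` signatures above; `JoblibConfig(n_jobs=4)` assumes the config dataclass exposes joblib's standard `n_jobs` field (only `backend` is visible in this diff), and `experiments` stands for a previously built sequence of `Experiment` objects.

```python
from tianshou.data import InfoStats
from tianshou.evaluation.launcher import (
    JoblibConfig,
    JoblibExpLauncher,
    SequentialExpLauncher,
)
from tianshou.highlevel.experiment import Experiment


# 1) Override how each experiment is executed
#    (default: lambda exp: exp.run().trainer_result),
#    e.g. to tinker with the experiment prior to execution.
def tinkering_runner(exp: Experiment) -> InfoStats | None:
    # ... adjust `exp` here before running it (illustrative) ...
    return exp.run().trainer_result


seq_launcher = SequentialExpLauncher(experiment_runner=tinkering_runner)
stats = seq_launcher.launch(experiments)  # raises only if all experiments fail

# 2) Parallel execution via joblib; the launcher forces backend="loky" internally.
par_launcher = JoblibExpLauncher(joblib_cfg=JoblibConfig(n_jobs=4))
stats = par_launcher.launch(experiments)
```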
tianshou/highlevel/config.py

```diff
@@ -85,7 +85,7 @@ class SamplingConfig(ToStringMixin):
     """
     controls, within one gradient update step of an on-policy algorithm, the number of times an
     actual gradient update is applied using the full collected dataset, i.e. if the parameter is
-    `n`, then the collected data shall be used five times to update the policy within the same
+    5, then the collected data shall be used five times to update the policy within the same
     training step.

     The parameter is ignored and may be set to None for off-policy and offline algorithms.
```
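To make the corrected docstring concrete, a small sketch (parameter names taken from the example script above; the import path is assumed from this hunk):

```python
from tianshou.highlevel.config import SamplingConfig

# With repeat_per_collect=5, an on-policy algorithm reuses each collected
# dataset five times to update the policy within the same training step.
sampling_config = SamplingConfig(
    num_epochs=1,
    step_per_collect=2048,
    repeat_per_collect=5,
)
```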
tianshou/highlevel/experiment.py

```diff
@@ -87,7 +87,6 @@ from tianshou.utils.string import ToStringMixin
 if TYPE_CHECKING:
     from tianshou.evaluation.launcher import ExpLauncher, RegisteredExpLauncher

-
 log = logging.getLogger(__name__)


@@ -356,15 +355,20 @@ class Experiment(ToStringMixin):


 class ExperimentCollection:
+    """Shallow wrapper around a list of experiments providing a simple interface for running them with a launcher."""
+
     def __init__(self, experiments: list[Experiment]):
         self.experiments = experiments

-    def run(self, launcher: Union["ExpLauncher", "RegisteredExpLauncher"]) -> None:
+    def run(
+        self,
+        launcher: Union["ExpLauncher", "RegisteredExpLauncher"],
+    ) -> list[InfoStats | None]:
         from tianshou.evaluation.launcher import RegisteredExpLauncher

         if isinstance(launcher, RegisteredExpLauncher):
             launcher = launcher.create_launcher()
-        launcher.launch(experiments=self.experiments)
+        return launcher.launch(experiments=self.experiments)


 class ExperimentBuilder:
```
tianshou/utils/print.py

```diff
@@ -1,20 +1,29 @@
+import pprint
 from collections.abc import Sequence
 from dataclasses import asdict, dataclass
-from pprint import pprint


 @dataclass
 class DataclassPPrintMixin:
-    def pprint_asdict(self, exclude_fields: Sequence[str] | None = None) -> None:
+    def pprint_asdict(self, exclude_fields: Sequence[str] | None = None, indent: int = 4) -> None:
         """Pretty-print the object as a dict, excluding specified fields.

         :param exclude_fields: A sequence of field names to exclude from the output.
             If None, no fields are excluded.
+        :param indent: The indentation to use when pretty-printing.
         """
-        print(f"{self.__class__.__name__}")
-        print("----------------------------------------")
+        print(self.pprints_asdict(exclude_fields=exclude_fields, indent=indent))
+
+    def pprints_asdict(self, exclude_fields: Sequence[str] | None = None, indent: int = 4) -> str:
+        """String corresponding to pretty-print of the object as a dict, excluding specified fields.
+
+        :param exclude_fields: A sequence of field names to exclude from the output.
+            If None, no fields are excluded.
+        :param indent: The indentation to use when pretty-printing.
+        """
+        prefix = f"{self.__class__.__name__}\n----------------------------------------\n"
         print_dict = asdict(self)
         exclude_fields = exclude_fields or []
         for field in exclude_fields:
             print_dict.pop(field, None)
-        pprint(print_dict)
+        return prefix + pprint.pformat(print_dict, indent=indent)
```
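A short usage sketch of the extended mixin (the module path is assumed from the hunk above, and `TrainingSummary` is a hypothetical dataclass for illustration):

```python
from dataclasses import dataclass

from tianshou.utils.print import DataclassPPrintMixin


@dataclass
class TrainingSummary(DataclassPPrintMixin):  # hypothetical example class
    epochs: int = 1
    secret_token: str = "hidden"


summary = TrainingSummary()
summary.pprint_asdict(exclude_fields=["secret_token"])  # prints class name, separator, dict
as_string = summary.pprints_asdict(indent=2)  # new in this commit: pretty repr as a string
```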