diff --git a/examples/mujoco/mujoco_ppo_hl_multi.py b/examples/mujoco/mujoco_ppo_hl_multi.py
index 4408a13..319375f 100644
--- a/examples/mujoco/mujoco_ppo_hl_multi.py
+++ b/examples/mujoco/mujoco_ppo_hl_multi.py
@@ -14,8 +14,6 @@ These plots are saved in the log directory and displayed in the console.
 
 import os
 import sys
-from collections.abc import Sequence
-from typing import Literal
 
 import torch
 
@@ -41,86 +39,30 @@ log = logging.getLogger(__name__)
 
 
 def main(
-    experiment_config: ExperimentConfig,
-    task: str = "Ant-v4",
-    num_experiments: int = 5,
-    buffer_size: int = 4096,
-    hidden_sizes: Sequence[int] = (64, 64),
-    lr: float = 3e-4,
-    gamma: float = 0.99,
-    epoch: int = 3,
-    step_per_epoch: int = 30000,
-    step_per_collect: int = 2048,
-    repeat_per_collect: int = 10,
-    batch_size: int = 64,
-    training_num: int = 10,
-    test_num: int = 10,
-    rew_norm: bool = True,
-    vf_coef: float = 0.25,
-    ent_coef: float = 0.0,
-    gae_lambda: float = 0.95,
-    bound_action_method: Literal["clip", "tanh"] | None = "clip",
-    lr_decay: bool = True,
-    max_grad_norm: float = 0.5,
-    eps_clip: float = 0.2,
-    dual_clip: float | None = None,
-    value_clip: bool = False,
-    norm_adv: bool = False,
-    recompute_adv: bool = True,
+    num_experiments: int = 2,
     run_experiments_sequentially: bool = True,
-) -> str:
-    """Use the high-level API of TianShou to evaluate the PPO algorithm on a MuJoCo environment with multiple seeds for
-    a given configuration. The results for each run are stored in separate sub-folders. After the agents are trained,
-    the results are evaluated using the rliable API.
-
-    :param experiment_config:
-    :param task: a mujoco task name
-    :param num_experiments: how many experiments to run with different seeds
-    :param buffer_size:
-    :param hidden_sizes:
-    :param lr:
-    :param gamma:
-    :param epoch:
-    :param step_per_epoch:
-    :param step_per_collect:
-    :param repeat_per_collect:
-    :param batch_size:
-    :param training_num:
-    :param test_num:
-    :param rew_norm:
-    :param vf_coef:
-    :param ent_coef:
-    :param gae_lambda:
-    :param bound_action_method:
-    :param lr_decay:
-    :param max_grad_norm:
-    :param eps_clip:
-    :param dual_clip:
-    :param value_clip:
-    :param norm_adv:
-    :param recompute_adv:
-    :param run_experiments_sequentially: if True, the experiments are run sequentially, otherwise in parallel.
+) -> RLiableExperimentResult:
+    """Use the high-level API of TianShou to evaluate the PPO algorithm on a MuJoCo environment with multiple
+    seeds for a fixed configuration. The results for each run are stored in separate sub-folders. After the
+    agents are trained, the results are evaluated using the rliable API.
+
+    :param num_experiments: how many experiments to run with different seeds
+    :param run_experiments_sequentially: if True, the experiments are run sequentially, otherwise in parallel.
+        LIMITATIONS: currently, parallel execution does not seem to work properly on Linux. It is also
+        generally inadvisable to run multiple experiments in parallel on the same machine, as a single
+        experiment already uses all available CPU cores by default.
-    :return: the directory where the results are stored
+    :return: the rliable evaluation results, aggregated over the seeded runs
     """
+    task = "Ant-v4"
     persistence_dir = os.path.abspath(os.path.join("log", task, "ppo", datetime_tag()))
-    experiment_config.persistence_base_dir = persistence_dir
-    log.info(f"Will save all experiment results to {persistence_dir}.")
-    experiment_config.watch = False
+    experiment_config = ExperimentConfig(persistence_base_dir=persistence_dir, watch=False)
 
     sampling_config = SamplingConfig(
-        num_epochs=epoch,
-        step_per_epoch=step_per_epoch,
-        batch_size=batch_size,
-        num_train_envs=training_num,
-        num_test_envs=test_num,
-        num_test_episodes=test_num,
-        buffer_size=buffer_size,
-        step_per_collect=step_per_collect,
-        repeat_per_collect=repeat_per_collect,
+        num_epochs=1,
+        step_per_epoch=5000,
+        batch_size=64,
+        num_train_envs=10,
+        num_test_envs=10,
+        num_test_episodes=10,
+        buffer_size=4096,
+        step_per_collect=2048,
+        repeat_per_collect=10,
     )
 
     env_factory = MujocoEnvFactory(
@@ -133,52 +75,45 @@ def main(
         else VectorEnvType.SUBPROC_SHARED_MEM,
     )
 
-    experiments = (
+    hidden_sizes = (64, 64)
+
+    experiment_collection = (
         PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
         .with_ppo_params(
             PPOParams(
-                discount_factor=gamma,
-                gae_lambda=gae_lambda,
-                action_bound_method=bound_action_method,
-                reward_normalization=rew_norm,
-                ent_coef=ent_coef,
-                vf_coef=vf_coef,
-                max_grad_norm=max_grad_norm,
-                value_clip=value_clip,
-                advantage_normalization=norm_adv,
-                eps_clip=eps_clip,
-                dual_clip=dual_clip,
-                recompute_advantage=recompute_adv,
-                lr=lr,
-                lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
-                if lr_decay
-                else None,
+                discount_factor=0.99,
+                gae_lambda=0.95,
+                action_bound_method="clip",
+                reward_normalization=True,
+                ent_coef=0.0,
+                vf_coef=0.25,
+                max_grad_norm=0.5,
+                value_clip=False,
+                advantage_normalization=False,
+                eps_clip=0.2,
+                dual_clip=None,
+                recompute_advantage=True,
+                lr=3e-4,
+                lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config),
                 dist_fn=DistributionFunctionFactoryIndependentGaussians(),
             ),
         )
         .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
         .with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
         .with_logger_factory(LoggerFactoryDefault("tensorboard"))
-        .build_default_seeded_experiments(num_experiments)
+        .build_seeded_collection(num_experiments)
    )
 
     if run_experiments_sequentially:
         launcher = RegisteredExpLauncher.sequential.create_launcher()
     else:
         launcher = RegisteredExpLauncher.joblib.create_launcher()
-    launcher.launch(experiments)
+    experiment_collection.run(launcher)
 
-    return persistence_dir
-
-
-def eval_experiments(log_dir: str) -> RLiableExperimentResult:
-    """Evaluate the experiments in the given log directory using the rliable API."""
-    rliable_result = RLiableExperimentResult.load_from_disk(log_dir)
+    rliable_result = RLiableExperimentResult.load_from_disk(persistence_dir)
     rliable_result.eval_results(show_plots=True, save_plots=True)
     return rliable_result
 
 
 if __name__ == "__main__":
-    log_dir = logging.run_cli(main, level=logging.INFO)
-    assert isinstance(log_dir, str)  # for mypy
-    evaluation_result = eval_experiments(log_dir)
+    result = logging.run_cli(main, level=logging.INFO)
diff --git a/test/highlevel/test_experiment_builder.py b/test/highlevel/test_experiment_builder.py
index 0ba8a7b..cb52c5a 100644
--- a/test/highlevel/test_experiment_builder.py
+++ b/test/highlevel/test_experiment_builder.py
@@ -79,30 +79,3 @@ def test_experiment_builder_discrete_default_params(builder_cls: type[Experiment
     experiment = builder.build()
     experiment.run(run_name="test")
     print(experiment)
-
-
-def test_temp_builder_modification() -> None:
-    env_factory = DiscreteTestEnvFactory()
-    sampling_config = SamplingConfig(
-        num_epochs=1,
-        step_per_epoch=100,
-        num_train_envs=2,
-        num_test_envs=2,
-    )
-    builder = PPOExperimentBuilder(
-        experiment_config=ExperimentConfig(persistence_enabled=False),
-        env_factory=env_factory,
-        sampling_config=sampling_config,
-    )
-    original_seed = builder.experiment_config.seed
-    original_train_seed = builder.sampling_config.train_seed
-
-    with builder.temp_config_mutation():
-        builder.experiment_config.seed += 12345
-        builder.sampling_config.train_seed += 456
-        exp = builder.build()
-
-    assert builder.experiment_config.seed == original_seed
-    assert builder.sampling_config.train_seed == original_train_seed
-    assert exp.config.seed == original_seed + 12345
-    assert exp.sampling_config.train_seed == original_train_seed + 456
diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py
index e31a392..8fc21cf 100644
--- a/tianshou/highlevel/experiment.py
+++ b/tianshou/highlevel/experiment.py
@@ -1,12 +1,11 @@
 import os
 import pickle
 from abc import abstractmethod
-from collections.abc import Iterator, Sequence
-from contextlib import contextmanager
-from copy import copy
+from collections.abc import Sequence
+from copy import deepcopy
 from dataclasses import dataclass
 from pprint import pformat
-from typing import Self, Dict, Any
+from typing import TYPE_CHECKING, Any, Self, Union, cast
 
 import numpy as np
 import torch
@@ -85,6 +84,10 @@ from tianshou.utils.logging import datetime_tag
 from tianshou.utils.net.common import ModuleType
 from tianshou.utils.string import ToStringMixin
 
+if TYPE_CHECKING:
+    from tianshou.evaluation.launcher import ExpLauncher, RegisteredExpLauncher
+
+
 log = logging.getLogger(__name__)
 
 
@@ -157,19 +160,6 @@ class Experiment(ToStringMixin):
         self.logger_factory = logger_factory
         self.name = name
 
-    def get_seeding_info_as_str(self) -> str:
-        """Useful for creating unique experiment names based on seeds.
-
-        A typical example is to do `experiment.name = f"{experiment.name}_{experiment.get_seeding_info_as_str()}"`.
-        """
-        return "_".join(
-            [
-                f"exp_seed={self.config.seed}",
-                f"train_seed={self.sampling_config.train_seed}",
-                f"test_seed={self.sampling_config.test_seed}",
-            ],
-        )
-
     @classmethod
     def from_directory(cls, directory: str, restore_policy: bool = True) -> "Experiment":
         """Restores an experiment from a previously stored pickle.
@@ -184,6 +174,20 @@ class Experiment(ToStringMixin):
         experiment.config.policy_restore_directory = directory
         return experiment
 
+    def get_seeding_info_as_str(self) -> str:
+        """Returns a string with information on the seeds used in the experiment.
+
+        This can be useful for creating unique experiment names based on seeds, e.g. by setting
+        `experiment.name = f"{experiment.name}_{experiment.get_seeding_info_as_str()}"`.
+        """
+        return "_".join(
+            [
+                f"exp_seed={self.config.seed}",
+                f"train_seed={self.sampling_config.train_seed}",
+                f"test_seed={self.sampling_config.test_seed}",
+            ],
+        )
+
     def _set_seed(self) -> None:
         seed = self.config.seed
         log.info(f"Setting random seed {seed}")
@@ -206,7 +210,7 @@ class Experiment(ToStringMixin):
         run_name: str | None = None,
         logger_run_id: str | None = None,
         raise_error_on_dirname_collision: bool = True,
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ) -> ExperimentResult:
         """Run the experiment and return the results.
@@ -225,7 +229,7 @@ class Experiment(ToStringMixin):
         # backward compatibility
         _experiment_name = kwargs.pop("experiment_name", None)
         if _experiment_name is not None:
-            run_name = _experiment_name
+            run_name = cast(str, _experiment_name)
             deprecation(
                 "Parameter run_name should now be used instead of experiment_name. "
                 "Support for experiment_name will be removed in the future.",
@@ -352,6 +356,18 @@ class Experiment(ToStringMixin):
         )
 
 
+class ExperimentCollection:
+    """Represents a collection of experiments, which can be run via a launcher (sequentially or in parallel)."""
+
+    def __init__(self, experiments: list[Experiment]):
+        self.experiments = experiments
+
+    def run(self, launcher: Union["ExpLauncher", "RegisteredExpLauncher"]) -> None:
+        from tianshou.evaluation.launcher import RegisteredExpLauncher
+
+        if isinstance(launcher, RegisteredExpLauncher):
+            launcher = launcher.create_launcher()
+        launcher.launch(experiments=self.experiments)
+
+
 class ExperimentBuilder:
     def __init__(
         self,
@@ -372,14 +388,8 @@ class ExperimentBuilder:
         self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks()
         self._name: str = self.__class__.__name__.replace("Builder", "") + "_" + datetime_tag()
 
-    @contextmanager
-    def temp_config_mutation(self) -> Iterator[Self]:
-        """Returns the builder instance where the configs can be modified without affecting the current instance."""
-        original_sampling_config = copy(self.sampling_config)
-        original_experiment_config = copy(self.experiment_config)
-        yield self
-        self.sampling_config = original_sampling_config
-        self.experiment_config = original_experiment_config
+    def copy(self) -> Self:
+        """Creates a deep copy of this builder, which can be modified without affecting the original."""
+        return deepcopy(self)
 
     @property
     def experiment_config(self) -> ExperimentConfig:
@@ -496,12 +506,9 @@ class ExperimentBuilder:
         else:
             return self._optim_factory
 
-    def build(self, add_seeding_info_to_name: bool = False) -> Experiment:
+    def build(self) -> Experiment:
         """Creates the experiment based on the options specified via this builder.
 
-        :param add_seeding_info_to_name: whether to add a postfix to the experiment name that contains
-            info about the training seeds. Useful for creating multiple experiments that only differ
-            by seeds.
         :return: the experiment
         """
         agent_factory = self._create_agent_factory()
@@ -516,27 +523,24 @@ class ExperimentBuilder:
             name=self._name,
             logger_factory=self._logger_factory,
         )
-        if add_seeding_info_to_name:
-            if not experiment.name:
-                experiment.name = experiment.get_seeding_info_as_str()
-            else:
-                experiment.name = f"{experiment.name}_{experiment.get_seeding_info_as_str()}"
         return experiment
 
-    def build_default_seeded_experiments(self, num_experiments: int) -> list[Experiment]:
-        """Creates a list of experiments with non-overlapping seeds, starting from the configured seed.
+    def build_seeded_collection(self, num_experiments: int) -> ExperimentCollection:
+        """Creates a collection of experiments with non-overlapping random seeds, starting from the configured seed.
 
-        Each experiment will have a unique name that is created from the original experiment name and the seeds used.
+        Each experiment in the collection will have a unique name that is created from the original experiment name and the seeds used.
""" num_train_envs = self.sampling_config.num_train_envs seeded_experiments = [] for i in range(num_experiments): - with self.temp_config_mutation(): - self.experiment_config.seed += i - self.sampling_config.train_seed += i * num_train_envs - seeded_experiments.append(self.build(add_seeding_info_to_name=True)) - return seeded_experiments + builder = self.copy() + builder.experiment_config.seed += i + builder.sampling_config.train_seed += i * num_train_envs + experiment = builder.build() + experiment.name += f"_{experiment.get_seeding_info_as_str()}" + seeded_experiments.append(experiment) + return ExperimentCollection(seeded_experiments) class _BuilderMixinActorFactory(ActorFutureProviderProtocol):