From ed12b16d70003dd1bfc826816cada1f3d9511899 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Wed, 3 Apr 2024 17:28:38 +0200 Subject: [PATCH] Added contextmanager for ExperimentBuilder modifications Used it to simplify building of seeded experiments --- test/highlevel/test_experiment_builder.py | 27 +++++++ tianshou/highlevel/experiment.py | 98 ++++++++++++++--------- 2 files changed, 88 insertions(+), 37 deletions(-) diff --git a/test/highlevel/test_experiment_builder.py b/test/highlevel/test_experiment_builder.py index 028c46c..21c79f5 100644 --- a/test/highlevel/test_experiment_builder.py +++ b/test/highlevel/test_experiment_builder.py @@ -79,3 +79,30 @@ def test_experiment_builder_discrete_default_params(builder_cls: type[Experiment experiment = builder.build() experiment.run("test") print(experiment) + + +def test_temp_builder_modification(): + env_factory = DiscreteTestEnvFactory() + sampling_config = SamplingConfig( + num_epochs=1, + step_per_epoch=100, + num_train_envs=2, + num_test_envs=2, + ) + builder = PPOExperimentBuilder( + experiment_config=ExperimentConfig(persistence_enabled=False), + env_factory=env_factory, + sampling_config=sampling_config, + ) + original_seed = builder.experiment_config.seed + original_train_seed = builder.sampling_config.train_seed + + with builder.temp_config_mutation(): + builder.experiment_config.seed += 12345 + builder.sampling_config.train_seed += 456 + exp = builder.build() + + assert builder.experiment_config.seed == original_seed + assert builder.sampling_config.train_seed == original_train_seed + assert exp.config.seed == original_seed + 12345 + assert exp.sampling_config.train_seed == original_train_seed + 456 diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index 8cabbea..431bccb 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -1,11 +1,12 @@ import os import pickle from abc import abstractmethod -from collections.abc import Sequence +from collections.abc import Iterator, Sequence +from contextlib import contextmanager from copy import copy -from dataclasses import asdict, dataclass +from dataclasses import dataclass from pprint import pformat -from typing import Self +from typing import Literal, Self import numpy as np import torch @@ -145,6 +146,7 @@ class Experiment(ToStringMixin): agent_factory: AgentFactory, sampling_config: SamplingConfig, logger_factory: LoggerFactory | None = None, + name: str | Literal["DATETIME_TAG"] = "DATETIME_TAG", ): if logger_factory is None: logger_factory = LoggerFactoryDefault() @@ -153,6 +155,22 @@ class Experiment(ToStringMixin): self.env_factory = env_factory self.agent_factory = agent_factory self.logger_factory = logger_factory + if name == "DATETIME_TAG": + name = datetime_tag() + self.name = name + + def get_seeding_info_as_str(self) -> str: + """Useful for creating unique experiment names based on seeds. + + A typical example is to do `experiment.name = f"{experiment.name}_{experiment.get_seeding_info_as_str()}"`. + """ + return ",".join( + [ + f"exp_seed={self.config.seed}", + f"train_seed={self.sampling_config.train_seed}", + f"test_seed={self.sampling_config.test_seed}", + ], + ) @classmethod def from_directory(cls, directory: str, restore_policy: bool = True) -> "Experiment": @@ -338,6 +356,16 @@ class ExperimentBuilder: self._optim_factory: OptimizerFactory | None = None self._policy_wrapper_factory: PolicyWrapperFactory | None = None self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks() + self._experiment_name: str = "" + + @contextmanager + def temp_config_mutation(self) -> Iterator[Self]: + """Returns the builder instance where the configs can be modified without affecting the current instance.""" + original_sampling_config = copy(self.sampling_config) + original_experiment_config = copy(self.experiment_config) + yield self + self.sampling_config = original_sampling_config + self.experiment_config = original_experiment_config @property def experiment_config(self) -> ExperimentConfig: @@ -431,6 +459,20 @@ class ExperimentBuilder: self._trainer_callbacks.epoch_stop_callback = callback return self + def with_experiment_name( + self, + experiment_name: str | Literal["DATETIME_TAG"] = "DATETIME_TAG", + ) -> Self: + """Sets the name of the experiment. + + :param experiment_name: the name. If "DATETIME_TAG" (default) is given, the current date and time will be used. + :return: the builder + """ + if experiment_name == "DATETIME_TAG": + experiment_name = datetime_tag() + self._experiment_name = experiment_name + return self + @abstractmethod def _create_agent_factory(self) -> AgentFactory: pass @@ -441,9 +483,12 @@ class ExperimentBuilder: else: return self._optim_factory - def build(self) -> Experiment: + def build(self, add_seeding_info_to_name: bool = False) -> Experiment: """Creates the experiment based on the options specified via this builder. + :param add_seeding_info_to_name: whether to add a postfix to the experiment name that contains + info about the training seeds. Useful for creating multiple experiments that only differ + by seeds. :return: the experiment """ agent_factory = self._create_agent_factory() @@ -456,46 +501,25 @@ class ExperimentBuilder: agent_factory, self._sampling_config, self._logger_factory, + name=self._experiment_name, ) + if add_seeding_info_to_name: + experiment.name = f"{experiment.name}_{experiment.get_seeding_info_as_str()}" return experiment - def build_default_seeded_experiments(self, num_experiments: int) -> dict[str, Experiment]: - """Creates a dict of experiments with non-overlapping seeds, starting from the configured seed. + def build_default_seeded_experiments(self, num_experiments: int) -> list[Experiment]: + """Creates a list of experiments with non-overlapping seeds, starting from the configured seed. - The keys of the dict are the experiment names, which are derived from the seeds used in the experiments. + Each experiment will have a unique name that is created from the original experiment name and the seeds used. """ - configured_experiment_config = copy(self.experiment_config) - configured_experiment_seed = configured_experiment_config.seed - configured_sampling_config = copy(self.sampling_config) - configured_train_seed = configured_sampling_config.train_seed - num_train_envs = configured_sampling_config.num_train_envs + num_train_envs = self.sampling_config.num_train_envs - seeded_experiments = {} + seeded_experiments = [] for i in range(num_experiments): - experiment_seed = configured_experiment_seed + i - new_experiment_config_dict = asdict(configured_experiment_config) - new_experiment_config_dict["seed"] = experiment_seed - - new_train_seed = i * num_train_envs + configured_train_seed - new_sampling_config_dict = asdict(configured_sampling_config) - new_sampling_config_dict["train_seed"] = new_train_seed - - self.experiment_config = ExperimentConfig(**new_experiment_config_dict) - self.sampling_config = SamplingConfig(**new_sampling_config_dict) - exp = self.build() - - experiment_name = ",".join( - [ - f"exp_seed={exp.config.seed}", - f"train_seed={exp.sampling_config.train_seed}", - f"test_seed={exp.sampling_config.test_seed}", - ], - ) - seeded_experiments[experiment_name] = exp - - # restore original config - self.experiment_config = configured_experiment_config - self.sampling_config = configured_sampling_config + with self.temp_config_mutation(): + self.experiment_config.seed += i + self.sampling_config.train_seed += i * num_train_envs + seeded_experiments.append(self.build(add_seeding_info_to_name=True)) return seeded_experiments