diff --git a/CHANGELOG.md b/CHANGELOG.md index 5b542ea..2c62cca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,11 @@ - `to_dict` in Batch supports also non-recursive conversion. #1098 - Batch `__eq__` implemented, semantic equality check of batches is now possible. #1098 - `Batch.keys()` deprecated in favor of `Batch.get_keys()` (needed to make iteration consistent with naming) #1105. -- `Experiment` and `ExperimentConfig` now have a `name`, that can however be overridden when `Experiment.run()` is called. #1074 +- `highlevel.experiment`: + - `Experiment` now has a `name` attribute, which can be set using `ExperimentBuilder.with_name` and + which determines the default run name and therefore the persistence subdirectory. + It can still be overridden in `Experiment.run()`, the new parameter name being `run_name` rather than + `experiment_name` (although the latter will still be interpreted correctly). #1074 - When building an `Experiment` from an `ExperimentConfig`, the user has the option to add info about seeds to the name. #1074 - New method in `ExperimentConfig` called `build_default_seeded_experiments`. #1074 - `SamplingConfig` has an explicit training seed, `test_seed` is inferred. #1074 diff --git a/examples/atari/atari_dqn_hl.py b/examples/atari/atari_dqn_hl.py index 289363e..aa76983 100644 --- a/examples/atari/atari_dqn_hl.py +++ b/examples/atari/atari_dqn_hl.py @@ -104,7 +104,7 @@ def main( ) experiment = builder.build() - experiment.run(override_experiment_name=log_name) + experiment.run(run_name=log_name) if __name__ == "__main__": diff --git a/examples/atari/atari_iqn_hl.py b/examples/atari/atari_iqn_hl.py index 850c0ff..23df1cd 100644 --- a/examples/atari/atari_iqn_hl.py +++ b/examples/atari/atari_iqn_hl.py @@ -96,7 +96,7 @@ def main( .with_epoch_stop_callback(AtariEpochStopCallback(task)) .build() ) - experiment.run(override_experiment_name=log_name) + experiment.run(run_name=log_name) if __name__ == "__main__": diff --git a/examples/atari/atari_ppo_hl.py b/examples/atari/atari_ppo_hl.py index ea45df5..10dcd0a 100644 --- a/examples/atari/atari_ppo_hl.py +++ b/examples/atari/atari_ppo_hl.py @@ -115,7 +115,7 @@ def main( ), ) experiment = builder.build() - experiment.run(override_experiment_name=log_name) + experiment.run(run_name=log_name) if __name__ == "__main__": diff --git a/examples/atari/atari_sac_hl.py b/examples/atari/atari_sac_hl.py index 8b1bf28..cf09b40 100644 --- a/examples/atari/atari_sac_hl.py +++ b/examples/atari/atari_sac_hl.py @@ -103,7 +103,7 @@ def main( ), ) experiment = builder.build() - experiment.run(override_experiment_name=log_name) + experiment.run(run_name=log_name) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_a2c_hl.py b/examples/mujoco/mujoco_a2c_hl.py index 96ad8c5..bce02e9 100644 --- a/examples/mujoco/mujoco_a2c_hl.py +++ b/examples/mujoco/mujoco_a2c_hl.py @@ -83,7 +83,7 @@ def main( .with_critic_factory_default(hidden_sizes, nn.Tanh) .build() ) - experiment.run(override_experiment_name=log_name) + experiment.run(run_name=log_name) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_ddpg_hl.py b/examples/mujoco/mujoco_ddpg_hl.py index a476245..db9c4e3 100644 --- a/examples/mujoco/mujoco_ddpg_hl.py +++ b/examples/mujoco/mujoco_ddpg_hl.py @@ -74,7 +74,7 @@ def main( .with_critic_factory_default(hidden_sizes) .build() ) - experiment.run(override_experiment_name=log_name) + experiment.run(run_name=log_name) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_npg_hl.py b/examples/mujoco/mujoco_npg_hl.py index 2e437ca..ab265a8 100644 --- a/examples/mujoco/mujoco_npg_hl.py +++ b/examples/mujoco/mujoco_npg_hl.py @@ -85,7 +85,7 @@ def main( .with_critic_factory_default(hidden_sizes, torch.nn.Tanh) .build() ) - experiment.run(override_experiment_name=log_name) + experiment.run(run_name=log_name) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_ppo_hl.py b/examples/mujoco/mujoco_ppo_hl.py index 601b084..27a701b 100644 --- a/examples/mujoco/mujoco_ppo_hl.py +++ b/examples/mujoco/mujoco_ppo_hl.py @@ -95,7 +95,7 @@ def main( .with_critic_factory_default(hidden_sizes, torch.nn.Tanh) .build() ) - experiment.run(override_experiment_name=log_name) + experiment.run(run_name=log_name) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_redq_hl.py b/examples/mujoco/mujoco_redq_hl.py index 9b4bca7..f523729 100644 --- a/examples/mujoco/mujoco_redq_hl.py +++ b/examples/mujoco/mujoco_redq_hl.py @@ -83,7 +83,7 @@ def main( .with_critic_ensemble_factory_default(hidden_sizes) .build() ) - experiment.run(override_experiment_name=log_name) + experiment.run(run_name=log_name) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_reinforce_hl.py b/examples/mujoco/mujoco_reinforce_hl.py index a5ec65f..46eb64f 100644 --- a/examples/mujoco/mujoco_reinforce_hl.py +++ b/examples/mujoco/mujoco_reinforce_hl.py @@ -72,7 +72,7 @@ def main( .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True) .build() ) - experiment.run(override_experiment_name=log_name) + experiment.run(run_name=log_name) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_sac_hl.py b/examples/mujoco/mujoco_sac_hl.py index 9ffa0f4..5ca7318 100644 --- a/examples/mujoco/mujoco_sac_hl.py +++ b/examples/mujoco/mujoco_sac_hl.py @@ -80,7 +80,7 @@ def main( .with_common_critic_factory_default(hidden_sizes) .build() ) - experiment.run(override_experiment_name=log_name) + experiment.run(run_name=log_name) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_td3_hl.py b/examples/mujoco/mujoco_td3_hl.py index 6adc73d..3a32c7f 100644 --- a/examples/mujoco/mujoco_td3_hl.py +++ b/examples/mujoco/mujoco_td3_hl.py @@ -85,7 +85,7 @@ def main( .with_common_critic_factory_default(hidden_sizes, torch.nn.Tanh) .build() ) - experiment.run(override_experiment_name=log_name) + experiment.run(run_name=log_name) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_trpo_hl.py b/examples/mujoco/mujoco_trpo_hl.py index 3af69bd..f54d4c3 100644 --- a/examples/mujoco/mujoco_trpo_hl.py +++ b/examples/mujoco/mujoco_trpo_hl.py @@ -89,7 +89,7 @@ def main( .with_critic_factory_default(hidden_sizes, torch.nn.Tanh) .build() ) - experiment.run(override_experiment_name=log_name) + experiment.run(run_name=log_name) if __name__ == "__main__": diff --git a/test/highlevel/test_experiment_builder.py b/test/highlevel/test_experiment_builder.py index 725d7f7..0ba8a7b 100644 --- a/test/highlevel/test_experiment_builder.py +++ b/test/highlevel/test_experiment_builder.py @@ -49,7 +49,7 @@ def test_experiment_builder_continuous_default_params(builder_cls: type[Experime sampling_config=sampling_config, ) experiment = builder.build() - experiment.run(override_experiment_name="test") + experiment.run(run_name="test") print(experiment) @@ -77,7 +77,7 @@ def test_experiment_builder_discrete_default_params(builder_cls: type[Experiment sampling_config=sampling_config, ) experiment = builder.build() - experiment.run(override_experiment_name="test") + experiment.run(run_name="test") print(experiment) diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index 6f9eb7c..e31a392 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -6,7 +6,7 @@ from contextlib import contextmanager from copy import copy from dataclasses import dataclass from pprint import pformat -from typing import Literal, Self +from typing import Self, Dict, Any import numpy as np import torch @@ -80,7 +80,7 @@ from tianshou.highlevel.trainer import ( ) from tianshou.highlevel.world import World from tianshou.policy import BasePolicy -from tianshou.utils import LazyLogger, logging +from tianshou.utils import LazyLogger, deprecation, logging from tianshou.utils.logging import datetime_tag from tianshou.utils.net.common import ModuleType from tianshou.utils.string import ToStringMixin @@ -145,8 +145,8 @@ class Experiment(ToStringMixin): env_factory: EnvFactory, agent_factory: AgentFactory, sampling_config: SamplingConfig, + name: str, logger_factory: LoggerFactory | None = None, - name: str | Literal["DATETIME_TAG"] = "DATETIME_TAG", ): if logger_factory is None: logger_factory = LoggerFactoryDefault() @@ -155,8 +155,6 @@ class Experiment(ToStringMixin): self.env_factory = env_factory self.agent_factory = agent_factory self.logger_factory = logger_factory - if name == "DATETIME_TAG": - name = datetime_tag() self.name = name def get_seeding_info_as_str(self) -> str: @@ -205,33 +203,41 @@ class Experiment(ToStringMixin): def run( self, - override_experiment_name: str | Literal["DATETIME_TAG"] | None = None, + run_name: str | None = None, logger_run_id: str | None = None, raise_error_on_dirname_collision: bool = True, + **kwargs: Dict[str, Any], ) -> ExperimentResult: """Run the experiment and return the results. - :param override_experiment_name: if not None, will adjust the current instance's `name` name attribute. - The name corresponds to the directory (within the logging - directory) where all results associated with the experiment will be saved. + :param run_name: Defines a name for this run of the experiment, which determines + the subdirectory (within the persistence base directory) where all results will be saved. + If None, the experiment's name will be used. The name may contain path separators (i.e. `os.path.sep`, as used by `os.path.join`), in which case a nested directory structure will be created. - If "DATETIME_TAG" is passed, use a name containing the current date and time. This option - is useful for preventing file-name collisions if a single experiment is executed repeatedly. :param logger_run_id: Run identifier to use for logger initialization/resumption (applies when using wandb, in particular). :param raise_error_on_dirname_collision: set to `False` e.g., when continuing a previously executed experiment with the same name. + :param kwargs: for backward compatibility with old parameter names only :return: """ - if override_experiment_name is not None: - if override_experiment_name == "DATETIME_TAG": - override_experiment_name = datetime_tag() - self.name = override_experiment_name + # backward compatibility + _experiment_name = kwargs.pop("experiment_name", None) + if _experiment_name is not None: + run_name = _experiment_name + deprecation( + "Parameter run_name should now be used instead of experiment_name. " + "Support for experiment_name will be removed in the future.", + ) + assert len(kwargs) == 0, f"Received unexpected arguments: {kwargs}" + + if run_name is None: + run_name = self.name # initialize persistence directory use_persistence = self.config.persistence_enabled - persistence_dir = os.path.join(self.config.persistence_base_dir, self.name) + persistence_dir = os.path.join(self.config.persistence_base_dir, run_name) if use_persistence: os.makedirs(persistence_dir, exist_ok=not raise_error_on_dirname_collision) @@ -240,7 +246,7 @@ class Experiment(ToStringMixin): enabled=use_persistence and self.config.log_file_enabled, ): # log initial information - log.info(f"Running experiment (name='{self.name}'):\n{self.pprints()}") + log.info(f"Running experiment (name='{run_name}'):\n{self.pprints()}") log.info(f"Working directory: {os.getcwd()}") self._set_seed() @@ -271,7 +277,7 @@ class Experiment(ToStringMixin): if use_persistence: logger = self.logger_factory.create_logger( log_dir=persistence_dir, - experiment_name=self.name, + experiment_name=run_name, run_id=logger_run_id, config_dict=full_config, ) @@ -364,7 +370,7 @@ class ExperimentBuilder: self._optim_factory: OptimizerFactory | None = None self._policy_wrapper_factory: PolicyWrapperFactory | None = None self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks() - self._experiment_name: str = "" + self._name: str = self.__class__.__name__.replace("Builder", "") + "_" + datetime_tag() @contextmanager def temp_config_mutation(self) -> Iterator[Self]: @@ -467,18 +473,17 @@ class ExperimentBuilder: self._trainer_callbacks.epoch_stop_callback = callback return self - def with_experiment_name( + def with_name( self, - experiment_name: str | Literal["DATETIME_TAG"] = "DATETIME_TAG", + name: str, ) -> Self: """Sets the name of the experiment. - :param experiment_name: the name. If "DATETIME_TAG" (default) is given, the current date and time will be used. + :param name: the name to use for this experiment, which, when the experiment is run, + will determine the storage sub-folder by default :return: the builder """ - if experiment_name == "DATETIME_TAG": - experiment_name = datetime_tag() - self._experiment_name = experiment_name + self._name = name return self @abstractmethod @@ -504,12 +509,12 @@ class ExperimentBuilder: if self._policy_wrapper_factory: agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory) experiment: Experiment = Experiment( - self._config, - self._env_factory, - agent_factory, - self._sampling_config, - self._logger_factory, - name=self._experiment_name, + config=self._config, + env_factory=self._env_factory, + agent_factory=agent_factory, + sampling_config=self._sampling_config, + name=self._name, + logger_factory=self._logger_factory, ) if add_seeding_info_to_name: if not experiment.name: