Clean up handling of an Experiment's name (and, by extension, a run's name)

This commit is contained in:
Dominik Jain 2024-04-30 16:12:43 +02:00
parent a65920fc68
commit 2b1594a1c8
16 changed files with 56 additions and 47 deletions

View File

@ -12,7 +12,11 @@
- `to_dict` in Batch supports also non-recursive conversion. #1098
- Batch `__eq__` implemented, semantic equality check of batches is now possible. #1098
- `Batch.keys()` deprecated in favor of `Batch.get_keys()` (needed to make iteration consistent with naming) #1105.
- `Experiment` and `ExperimentConfig` now have a `name`, that can however be overridden when `Experiment.run()` is called. #1074
- `highlevel.experiment`:
- `Experiment` now has a `name` attribute, which can be set using `ExperimentBuilder.with_name` and
which determines the default run name and therefore the persistence subdirectory.
It can still be overridden in `Experiment.run()`, the new parameter name being `run_name` rather than
`experiment_name` (although the latter will still be interpreted correctly). #1074
- When building an `Experiment` from an `ExperimentConfig`, the user has the option to add info about seeds to the name. #1074
- New method in `ExperimentConfig` called `build_default_seeded_experiments`. #1074
- `SamplingConfig` has an explicit training seed, `test_seed` is inferred. #1074

View File

@ -104,7 +104,7 @@ def main(
)
experiment = builder.build()
experiment.run(override_experiment_name=log_name)
experiment.run(run_name=log_name)
if __name__ == "__main__":

View File

@ -96,7 +96,7 @@ def main(
.with_epoch_stop_callback(AtariEpochStopCallback(task))
.build()
)
experiment.run(override_experiment_name=log_name)
experiment.run(run_name=log_name)
if __name__ == "__main__":

View File

@ -115,7 +115,7 @@ def main(
),
)
experiment = builder.build()
experiment.run(override_experiment_name=log_name)
experiment.run(run_name=log_name)
if __name__ == "__main__":

View File

@ -103,7 +103,7 @@ def main(
),
)
experiment = builder.build()
experiment.run(override_experiment_name=log_name)
experiment.run(run_name=log_name)
if __name__ == "__main__":

View File

@ -83,7 +83,7 @@ def main(
.with_critic_factory_default(hidden_sizes, nn.Tanh)
.build()
)
experiment.run(override_experiment_name=log_name)
experiment.run(run_name=log_name)
if __name__ == "__main__":

View File

@ -74,7 +74,7 @@ def main(
.with_critic_factory_default(hidden_sizes)
.build()
)
experiment.run(override_experiment_name=log_name)
experiment.run(run_name=log_name)
if __name__ == "__main__":

View File

@ -85,7 +85,7 @@ def main(
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
.build()
)
experiment.run(override_experiment_name=log_name)
experiment.run(run_name=log_name)
if __name__ == "__main__":

View File

@ -95,7 +95,7 @@ def main(
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
.build()
)
experiment.run(override_experiment_name=log_name)
experiment.run(run_name=log_name)
if __name__ == "__main__":

View File

@ -83,7 +83,7 @@ def main(
.with_critic_ensemble_factory_default(hidden_sizes)
.build()
)
experiment.run(override_experiment_name=log_name)
experiment.run(run_name=log_name)
if __name__ == "__main__":

View File

@ -72,7 +72,7 @@ def main(
.with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
.build()
)
experiment.run(override_experiment_name=log_name)
experiment.run(run_name=log_name)
if __name__ == "__main__":

View File

@ -80,7 +80,7 @@ def main(
.with_common_critic_factory_default(hidden_sizes)
.build()
)
experiment.run(override_experiment_name=log_name)
experiment.run(run_name=log_name)
if __name__ == "__main__":

View File

@ -85,7 +85,7 @@ def main(
.with_common_critic_factory_default(hidden_sizes, torch.nn.Tanh)
.build()
)
experiment.run(override_experiment_name=log_name)
experiment.run(run_name=log_name)
if __name__ == "__main__":

View File

@ -89,7 +89,7 @@ def main(
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
.build()
)
experiment.run(override_experiment_name=log_name)
experiment.run(run_name=log_name)
if __name__ == "__main__":

View File

@ -49,7 +49,7 @@ def test_experiment_builder_continuous_default_params(builder_cls: type[Experime
sampling_config=sampling_config,
)
experiment = builder.build()
experiment.run(override_experiment_name="test")
experiment.run(run_name="test")
print(experiment)
@ -77,7 +77,7 @@ def test_experiment_builder_discrete_default_params(builder_cls: type[Experiment
sampling_config=sampling_config,
)
experiment = builder.build()
experiment.run(override_experiment_name="test")
experiment.run(run_name="test")
print(experiment)

View File

@ -6,7 +6,7 @@ from contextlib import contextmanager
from copy import copy
from dataclasses import dataclass
from pprint import pformat
from typing import Literal, Self
from typing import Self, Dict, Any
import numpy as np
import torch
@ -80,7 +80,7 @@ from tianshou.highlevel.trainer import (
)
from tianshou.highlevel.world import World
from tianshou.policy import BasePolicy
from tianshou.utils import LazyLogger, logging
from tianshou.utils import LazyLogger, deprecation, logging
from tianshou.utils.logging import datetime_tag
from tianshou.utils.net.common import ModuleType
from tianshou.utils.string import ToStringMixin
@ -145,8 +145,8 @@ class Experiment(ToStringMixin):
env_factory: EnvFactory,
agent_factory: AgentFactory,
sampling_config: SamplingConfig,
name: str,
logger_factory: LoggerFactory | None = None,
name: str | Literal["DATETIME_TAG"] = "DATETIME_TAG",
):
if logger_factory is None:
logger_factory = LoggerFactoryDefault()
@ -155,8 +155,6 @@ class Experiment(ToStringMixin):
self.env_factory = env_factory
self.agent_factory = agent_factory
self.logger_factory = logger_factory
if name == "DATETIME_TAG":
name = datetime_tag()
self.name = name
def get_seeding_info_as_str(self) -> str:
@ -205,33 +203,41 @@ class Experiment(ToStringMixin):
def run(
self,
override_experiment_name: str | Literal["DATETIME_TAG"] | None = None,
run_name: str | None = None,
logger_run_id: str | None = None,
raise_error_on_dirname_collision: bool = True,
**kwargs: Dict[str, Any],
) -> ExperimentResult:
"""Run the experiment and return the results.
:param override_experiment_name: if not None, will adjust the current instance's `name` name attribute.
The name corresponds to the directory (within the logging
directory) where all results associated with the experiment will be saved.
:param run_name: Defines a name for this run of the experiment, which determines
the subdirectory (within the persistence base directory) where all results will be saved.
If None, the experiment's name will be used.
The name may contain path separators (i.e. `os.path.sep`, as used by `os.path.join`), in which case
a nested directory structure will be created.
If "DATETIME_TAG" is passed, use a name containing the current date and time. This option
is useful for preventing file-name collisions if a single experiment is executed repeatedly.
:param logger_run_id: Run identifier to use for logger initialization/resumption (applies when
using wandb, in particular).
:param raise_error_on_dirname_collision: set to `False` e.g., when continuing a previously executed
experiment with the same name.
:param kwargs: for backward compatibility with old parameter names only
:return:
"""
if override_experiment_name is not None:
if override_experiment_name == "DATETIME_TAG":
override_experiment_name = datetime_tag()
self.name = override_experiment_name
# backward compatibility
_experiment_name = kwargs.pop("experiment_name", None)
if _experiment_name is not None:
run_name = _experiment_name
deprecation(
"Parameter run_name should now be used instead of experiment_name. "
"Support for experiment_name will be removed in the future.",
)
assert len(kwargs) == 0, f"Received unexpected arguments: {kwargs}"
if run_name is None:
run_name = self.name
# initialize persistence directory
use_persistence = self.config.persistence_enabled
persistence_dir = os.path.join(self.config.persistence_base_dir, self.name)
persistence_dir = os.path.join(self.config.persistence_base_dir, run_name)
if use_persistence:
os.makedirs(persistence_dir, exist_ok=not raise_error_on_dirname_collision)
@ -240,7 +246,7 @@ class Experiment(ToStringMixin):
enabled=use_persistence and self.config.log_file_enabled,
):
# log initial information
log.info(f"Running experiment (name='{self.name}'):\n{self.pprints()}")
log.info(f"Running experiment (name='{run_name}'):\n{self.pprints()}")
log.info(f"Working directory: {os.getcwd()}")
self._set_seed()
@ -271,7 +277,7 @@ class Experiment(ToStringMixin):
if use_persistence:
logger = self.logger_factory.create_logger(
log_dir=persistence_dir,
experiment_name=self.name,
experiment_name=run_name,
run_id=logger_run_id,
config_dict=full_config,
)
@ -364,7 +370,7 @@ class ExperimentBuilder:
self._optim_factory: OptimizerFactory | None = None
self._policy_wrapper_factory: PolicyWrapperFactory | None = None
self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks()
self._experiment_name: str = ""
self._name: str = self.__class__.__name__.replace("Builder", "") + "_" + datetime_tag()
@contextmanager
def temp_config_mutation(self) -> Iterator[Self]:
@ -467,18 +473,17 @@ class ExperimentBuilder:
self._trainer_callbacks.epoch_stop_callback = callback
return self
def with_experiment_name(
def with_name(
self,
experiment_name: str | Literal["DATETIME_TAG"] = "DATETIME_TAG",
name: str,
) -> Self:
"""Sets the name of the experiment.
:param experiment_name: the name. If "DATETIME_TAG" (default) is given, the current date and time will be used.
:param name: the name to use for this experiment, which, when the experiment is run,
will determine the storage sub-folder by default
:return: the builder
"""
if experiment_name == "DATETIME_TAG":
experiment_name = datetime_tag()
self._experiment_name = experiment_name
self._name = name
return self
@abstractmethod
@ -504,12 +509,12 @@ class ExperimentBuilder:
if self._policy_wrapper_factory:
agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory)
experiment: Experiment = Experiment(
self._config,
self._env_factory,
agent_factory,
self._sampling_config,
self._logger_factory,
name=self._experiment_name,
config=self._config,
env_factory=self._env_factory,
agent_factory=agent_factory,
sampling_config=self._sampling_config,
name=self._name,
logger_factory=self._logger_factory,
)
if add_seeding_info_to_name:
if not experiment.name: