Clean up handling of an Experiment's name (and, by extension, a run's name)
This commit is contained in:
parent
a65920fc68
commit
2b1594a1c8
@ -12,7 +12,11 @@
|
||||
- `to_dict` in Batch supports also non-recursive conversion. #1098
|
||||
- Batch `__eq__` implemented, semantic equality check of batches is now possible. #1098
|
||||
- `Batch.keys()` deprecated in favor of `Batch.get_keys()` (needed to make iteration consistent with naming) #1105.
|
||||
- `Experiment` and `ExperimentConfig` now have a `name`, that can however be overridden when `Experiment.run()` is called. #1074
|
||||
- `highlevel.experiment`:
|
||||
- `Experiment` now has a `name` attribute, which can be set using `ExperimentBuilder.with_name` and
|
||||
which determines the default run name and therefore the persistence subdirectory.
|
||||
It can still be overridden in `Experiment.run()`, the new parameter name being `run_name` rather than
|
||||
`experiment_name` (although the latter will still be interpreted correctly). #1074
|
||||
- When building an `Experiment` from an `ExperimentConfig`, the user has the option to add info about seeds to the name. #1074
|
||||
- New method in `ExperimentConfig` called `build_default_seeded_experiments`. #1074
|
||||
- `SamplingConfig` has an explicit training seed, `test_seed` is inferred. #1074
|
||||
|
@ -104,7 +104,7 @@ def main(
|
||||
)
|
||||
|
||||
experiment = builder.build()
|
||||
experiment.run(override_experiment_name=log_name)
|
||||
experiment.run(run_name=log_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -96,7 +96,7 @@ def main(
|
||||
.with_epoch_stop_callback(AtariEpochStopCallback(task))
|
||||
.build()
|
||||
)
|
||||
experiment.run(override_experiment_name=log_name)
|
||||
experiment.run(run_name=log_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -115,7 +115,7 @@ def main(
|
||||
),
|
||||
)
|
||||
experiment = builder.build()
|
||||
experiment.run(override_experiment_name=log_name)
|
||||
experiment.run(run_name=log_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -103,7 +103,7 @@ def main(
|
||||
),
|
||||
)
|
||||
experiment = builder.build()
|
||||
experiment.run(override_experiment_name=log_name)
|
||||
experiment.run(run_name=log_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -83,7 +83,7 @@ def main(
|
||||
.with_critic_factory_default(hidden_sizes, nn.Tanh)
|
||||
.build()
|
||||
)
|
||||
experiment.run(override_experiment_name=log_name)
|
||||
experiment.run(run_name=log_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -74,7 +74,7 @@ def main(
|
||||
.with_critic_factory_default(hidden_sizes)
|
||||
.build()
|
||||
)
|
||||
experiment.run(override_experiment_name=log_name)
|
||||
experiment.run(run_name=log_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -85,7 +85,7 @@ def main(
|
||||
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
|
||||
.build()
|
||||
)
|
||||
experiment.run(override_experiment_name=log_name)
|
||||
experiment.run(run_name=log_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -95,7 +95,7 @@ def main(
|
||||
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
|
||||
.build()
|
||||
)
|
||||
experiment.run(override_experiment_name=log_name)
|
||||
experiment.run(run_name=log_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -83,7 +83,7 @@ def main(
|
||||
.with_critic_ensemble_factory_default(hidden_sizes)
|
||||
.build()
|
||||
)
|
||||
experiment.run(override_experiment_name=log_name)
|
||||
experiment.run(run_name=log_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -72,7 +72,7 @@ def main(
|
||||
.with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
|
||||
.build()
|
||||
)
|
||||
experiment.run(override_experiment_name=log_name)
|
||||
experiment.run(run_name=log_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -80,7 +80,7 @@ def main(
|
||||
.with_common_critic_factory_default(hidden_sizes)
|
||||
.build()
|
||||
)
|
||||
experiment.run(override_experiment_name=log_name)
|
||||
experiment.run(run_name=log_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -85,7 +85,7 @@ def main(
|
||||
.with_common_critic_factory_default(hidden_sizes, torch.nn.Tanh)
|
||||
.build()
|
||||
)
|
||||
experiment.run(override_experiment_name=log_name)
|
||||
experiment.run(run_name=log_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -89,7 +89,7 @@ def main(
|
||||
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
|
||||
.build()
|
||||
)
|
||||
experiment.run(override_experiment_name=log_name)
|
||||
experiment.run(run_name=log_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -49,7 +49,7 @@ def test_experiment_builder_continuous_default_params(builder_cls: type[Experime
|
||||
sampling_config=sampling_config,
|
||||
)
|
||||
experiment = builder.build()
|
||||
experiment.run(override_experiment_name="test")
|
||||
experiment.run(run_name="test")
|
||||
print(experiment)
|
||||
|
||||
|
||||
@ -77,7 +77,7 @@ def test_experiment_builder_discrete_default_params(builder_cls: type[Experiment
|
||||
sampling_config=sampling_config,
|
||||
)
|
||||
experiment = builder.build()
|
||||
experiment.run(override_experiment_name="test")
|
||||
experiment.run(run_name="test")
|
||||
print(experiment)
|
||||
|
||||
|
||||
|
@ -6,7 +6,7 @@ from contextlib import contextmanager
|
||||
from copy import copy
|
||||
from dataclasses import dataclass
|
||||
from pprint import pformat
|
||||
from typing import Literal, Self
|
||||
from typing import Self, Dict, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -80,7 +80,7 @@ from tianshou.highlevel.trainer import (
|
||||
)
|
||||
from tianshou.highlevel.world import World
|
||||
from tianshou.policy import BasePolicy
|
||||
from tianshou.utils import LazyLogger, logging
|
||||
from tianshou.utils import LazyLogger, deprecation, logging
|
||||
from tianshou.utils.logging import datetime_tag
|
||||
from tianshou.utils.net.common import ModuleType
|
||||
from tianshou.utils.string import ToStringMixin
|
||||
@ -145,8 +145,8 @@ class Experiment(ToStringMixin):
|
||||
env_factory: EnvFactory,
|
||||
agent_factory: AgentFactory,
|
||||
sampling_config: SamplingConfig,
|
||||
name: str,
|
||||
logger_factory: LoggerFactory | None = None,
|
||||
name: str | Literal["DATETIME_TAG"] = "DATETIME_TAG",
|
||||
):
|
||||
if logger_factory is None:
|
||||
logger_factory = LoggerFactoryDefault()
|
||||
@ -155,8 +155,6 @@ class Experiment(ToStringMixin):
|
||||
self.env_factory = env_factory
|
||||
self.agent_factory = agent_factory
|
||||
self.logger_factory = logger_factory
|
||||
if name == "DATETIME_TAG":
|
||||
name = datetime_tag()
|
||||
self.name = name
|
||||
|
||||
def get_seeding_info_as_str(self) -> str:
|
||||
@ -205,33 +203,41 @@ class Experiment(ToStringMixin):
|
||||
|
||||
def run(
|
||||
self,
|
||||
override_experiment_name: str | Literal["DATETIME_TAG"] | None = None,
|
||||
run_name: str | None = None,
|
||||
logger_run_id: str | None = None,
|
||||
raise_error_on_dirname_collision: bool = True,
|
||||
**kwargs: Dict[str, Any],
|
||||
) -> ExperimentResult:
|
||||
"""Run the experiment and return the results.
|
||||
|
||||
:param override_experiment_name: if not None, will adjust the current instance's `name` name attribute.
|
||||
The name corresponds to the directory (within the logging
|
||||
directory) where all results associated with the experiment will be saved.
|
||||
:param run_name: Defines a name for this run of the experiment, which determines
|
||||
the subdirectory (within the persistence base directory) where all results will be saved.
|
||||
If None, the experiment's name will be used.
|
||||
The name may contain path separators (i.e. `os.path.sep`, as used by `os.path.join`), in which case
|
||||
a nested directory structure will be created.
|
||||
If "DATETIME_TAG" is passed, use a name containing the current date and time. This option
|
||||
is useful for preventing file-name collisions if a single experiment is executed repeatedly.
|
||||
:param logger_run_id: Run identifier to use for logger initialization/resumption (applies when
|
||||
using wandb, in particular).
|
||||
:param raise_error_on_dirname_collision: set to `False` e.g., when continuing a previously executed
|
||||
experiment with the same name.
|
||||
:param kwargs: for backward compatibility with old parameter names only
|
||||
:return:
|
||||
"""
|
||||
if override_experiment_name is not None:
|
||||
if override_experiment_name == "DATETIME_TAG":
|
||||
override_experiment_name = datetime_tag()
|
||||
self.name = override_experiment_name
|
||||
# backward compatibility
|
||||
_experiment_name = kwargs.pop("experiment_name", None)
|
||||
if _experiment_name is not None:
|
||||
run_name = _experiment_name
|
||||
deprecation(
|
||||
"Parameter run_name should now be used instead of experiment_name. "
|
||||
"Support for experiment_name will be removed in the future.",
|
||||
)
|
||||
assert len(kwargs) == 0, f"Received unexpected arguments: {kwargs}"
|
||||
|
||||
if run_name is None:
|
||||
run_name = self.name
|
||||
|
||||
# initialize persistence directory
|
||||
use_persistence = self.config.persistence_enabled
|
||||
persistence_dir = os.path.join(self.config.persistence_base_dir, self.name)
|
||||
persistence_dir = os.path.join(self.config.persistence_base_dir, run_name)
|
||||
if use_persistence:
|
||||
os.makedirs(persistence_dir, exist_ok=not raise_error_on_dirname_collision)
|
||||
|
||||
@ -240,7 +246,7 @@ class Experiment(ToStringMixin):
|
||||
enabled=use_persistence and self.config.log_file_enabled,
|
||||
):
|
||||
# log initial information
|
||||
log.info(f"Running experiment (name='{self.name}'):\n{self.pprints()}")
|
||||
log.info(f"Running experiment (name='{run_name}'):\n{self.pprints()}")
|
||||
log.info(f"Working directory: {os.getcwd()}")
|
||||
|
||||
self._set_seed()
|
||||
@ -271,7 +277,7 @@ class Experiment(ToStringMixin):
|
||||
if use_persistence:
|
||||
logger = self.logger_factory.create_logger(
|
||||
log_dir=persistence_dir,
|
||||
experiment_name=self.name,
|
||||
experiment_name=run_name,
|
||||
run_id=logger_run_id,
|
||||
config_dict=full_config,
|
||||
)
|
||||
@ -364,7 +370,7 @@ class ExperimentBuilder:
|
||||
self._optim_factory: OptimizerFactory | None = None
|
||||
self._policy_wrapper_factory: PolicyWrapperFactory | None = None
|
||||
self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks()
|
||||
self._experiment_name: str = ""
|
||||
self._name: str = self.__class__.__name__.replace("Builder", "") + "_" + datetime_tag()
|
||||
|
||||
@contextmanager
|
||||
def temp_config_mutation(self) -> Iterator[Self]:
|
||||
@ -467,18 +473,17 @@ class ExperimentBuilder:
|
||||
self._trainer_callbacks.epoch_stop_callback = callback
|
||||
return self
|
||||
|
||||
def with_experiment_name(
|
||||
def with_name(
|
||||
self,
|
||||
experiment_name: str | Literal["DATETIME_TAG"] = "DATETIME_TAG",
|
||||
name: str,
|
||||
) -> Self:
|
||||
"""Sets the name of the experiment.
|
||||
|
||||
:param experiment_name: the name. If "DATETIME_TAG" (default) is given, the current date and time will be used.
|
||||
:param name: the name to use for this experiment, which, when the experiment is run,
|
||||
will determine the storage sub-folder by default
|
||||
:return: the builder
|
||||
"""
|
||||
if experiment_name == "DATETIME_TAG":
|
||||
experiment_name = datetime_tag()
|
||||
self._experiment_name = experiment_name
|
||||
self._name = name
|
||||
return self
|
||||
|
||||
@abstractmethod
|
||||
@ -504,12 +509,12 @@ class ExperimentBuilder:
|
||||
if self._policy_wrapper_factory:
|
||||
agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory)
|
||||
experiment: Experiment = Experiment(
|
||||
self._config,
|
||||
self._env_factory,
|
||||
agent_factory,
|
||||
self._sampling_config,
|
||||
self._logger_factory,
|
||||
name=self._experiment_name,
|
||||
config=self._config,
|
||||
env_factory=self._env_factory,
|
||||
agent_factory=agent_factory,
|
||||
sampling_config=self._sampling_config,
|
||||
name=self._name,
|
||||
logger_factory=self._logger_factory,
|
||||
)
|
||||
if add_seeding_info_to_name:
|
||||
if not experiment.name:
|
||||
|
Loading…
x
Reference in New Issue
Block a user