Clean up handling of an Experiment's name (and, by extension, a run's name)
This commit is contained in:
parent
a8e9df31f7
commit
35779696ee
@ -104,7 +104,7 @@ def main(
|
|||||||
)
|
)
|
||||||
|
|
||||||
experiment = builder.build()
|
experiment = builder.build()
|
||||||
experiment.run(override_experiment_name=log_name)
|
experiment.run(run_name=log_name)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -96,7 +96,7 @@ def main(
|
|||||||
.with_epoch_stop_callback(AtariEpochStopCallback(task))
|
.with_epoch_stop_callback(AtariEpochStopCallback(task))
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
experiment.run(override_experiment_name=log_name)
|
experiment.run(run_name=log_name)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -115,7 +115,7 @@ def main(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
experiment = builder.build()
|
experiment = builder.build()
|
||||||
experiment.run(override_experiment_name=log_name)
|
experiment.run(run_name=log_name)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -103,7 +103,7 @@ def main(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
experiment = builder.build()
|
experiment = builder.build()
|
||||||
experiment.run(override_experiment_name=log_name)
|
experiment.run(run_name=log_name)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -83,7 +83,7 @@ def main(
|
|||||||
.with_critic_factory_default(hidden_sizes, nn.Tanh)
|
.with_critic_factory_default(hidden_sizes, nn.Tanh)
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
experiment.run(override_experiment_name=log_name)
|
experiment.run(run_name=log_name)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -74,7 +74,7 @@ def main(
|
|||||||
.with_critic_factory_default(hidden_sizes)
|
.with_critic_factory_default(hidden_sizes)
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
experiment.run(override_experiment_name=log_name)
|
experiment.run(run_name=log_name)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -85,7 +85,7 @@ def main(
|
|||||||
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
|
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
experiment.run(override_experiment_name=log_name)
|
experiment.run(run_name=log_name)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -95,7 +95,7 @@ def main(
|
|||||||
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
|
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
experiment.run(override_experiment_name=log_name)
|
experiment.run(run_name=log_name)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -83,7 +83,7 @@ def main(
|
|||||||
.with_critic_ensemble_factory_default(hidden_sizes)
|
.with_critic_ensemble_factory_default(hidden_sizes)
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
experiment.run(override_experiment_name=log_name)
|
experiment.run(run_name=log_name)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -72,7 +72,7 @@ def main(
|
|||||||
.with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
|
.with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
experiment.run(override_experiment_name=log_name)
|
experiment.run(run_name=log_name)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -80,7 +80,7 @@ def main(
|
|||||||
.with_common_critic_factory_default(hidden_sizes)
|
.with_common_critic_factory_default(hidden_sizes)
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
experiment.run(override_experiment_name=log_name)
|
experiment.run(run_name=log_name)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -85,7 +85,7 @@ def main(
|
|||||||
.with_common_critic_factory_default(hidden_sizes, torch.nn.Tanh)
|
.with_common_critic_factory_default(hidden_sizes, torch.nn.Tanh)
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
experiment.run(override_experiment_name=log_name)
|
experiment.run(run_name=log_name)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -89,7 +89,7 @@ def main(
|
|||||||
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
|
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
experiment.run(override_experiment_name=log_name)
|
experiment.run(run_name=log_name)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -49,7 +49,7 @@ def test_experiment_builder_continuous_default_params(builder_cls: type[Experime
|
|||||||
sampling_config=sampling_config,
|
sampling_config=sampling_config,
|
||||||
)
|
)
|
||||||
experiment = builder.build()
|
experiment = builder.build()
|
||||||
experiment.run(override_experiment_name="test")
|
experiment.run(run_name="test")
|
||||||
print(experiment)
|
print(experiment)
|
||||||
|
|
||||||
|
|
||||||
@ -77,7 +77,7 @@ def test_experiment_builder_discrete_default_params(builder_cls: type[Experiment
|
|||||||
sampling_config=sampling_config,
|
sampling_config=sampling_config,
|
||||||
)
|
)
|
||||||
experiment = builder.build()
|
experiment = builder.build()
|
||||||
experiment.run(override_experiment_name="test")
|
experiment.run(run_name="test")
|
||||||
print(experiment)
|
print(experiment)
|
||||||
|
|
||||||
|
|
||||||
|
@ -6,7 +6,7 @@ from contextlib import contextmanager
|
|||||||
from copy import copy
|
from copy import copy
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
from typing import Literal, Self
|
from typing import Self, Dict, Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -80,7 +80,7 @@ from tianshou.highlevel.trainer import (
|
|||||||
)
|
)
|
||||||
from tianshou.highlevel.world import World
|
from tianshou.highlevel.world import World
|
||||||
from tianshou.policy import BasePolicy
|
from tianshou.policy import BasePolicy
|
||||||
from tianshou.utils import LazyLogger, logging
|
from tianshou.utils import LazyLogger, deprecation, logging
|
||||||
from tianshou.utils.logging import datetime_tag
|
from tianshou.utils.logging import datetime_tag
|
||||||
from tianshou.utils.net.common import ModuleType
|
from tianshou.utils.net.common import ModuleType
|
||||||
from tianshou.utils.string import ToStringMixin
|
from tianshou.utils.string import ToStringMixin
|
||||||
@ -145,8 +145,8 @@ class Experiment(ToStringMixin):
|
|||||||
env_factory: EnvFactory,
|
env_factory: EnvFactory,
|
||||||
agent_factory: AgentFactory,
|
agent_factory: AgentFactory,
|
||||||
sampling_config: SamplingConfig,
|
sampling_config: SamplingConfig,
|
||||||
|
name: str,
|
||||||
logger_factory: LoggerFactory | None = None,
|
logger_factory: LoggerFactory | None = None,
|
||||||
name: str | Literal["DATETIME_TAG"] = "DATETIME_TAG",
|
|
||||||
):
|
):
|
||||||
if logger_factory is None:
|
if logger_factory is None:
|
||||||
logger_factory = LoggerFactoryDefault()
|
logger_factory = LoggerFactoryDefault()
|
||||||
@ -155,8 +155,6 @@ class Experiment(ToStringMixin):
|
|||||||
self.env_factory = env_factory
|
self.env_factory = env_factory
|
||||||
self.agent_factory = agent_factory
|
self.agent_factory = agent_factory
|
||||||
self.logger_factory = logger_factory
|
self.logger_factory = logger_factory
|
||||||
if name == "DATETIME_TAG":
|
|
||||||
name = datetime_tag()
|
|
||||||
self.name = name
|
self.name = name
|
||||||
|
|
||||||
def get_seeding_info_as_str(self) -> str:
|
def get_seeding_info_as_str(self) -> str:
|
||||||
@ -205,33 +203,41 @@ class Experiment(ToStringMixin):
|
|||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
override_experiment_name: str | Literal["DATETIME_TAG"] | None = None,
|
run_name: str | None = None,
|
||||||
logger_run_id: str | None = None,
|
logger_run_id: str | None = None,
|
||||||
raise_error_on_dirname_collision: bool = True,
|
raise_error_on_dirname_collision: bool = True,
|
||||||
|
**kwargs: Dict[str, Any],
|
||||||
) -> ExperimentResult:
|
) -> ExperimentResult:
|
||||||
"""Run the experiment and return the results.
|
"""Run the experiment and return the results.
|
||||||
|
|
||||||
:param override_experiment_name: if not None, will adjust the current instance's `name` name attribute.
|
:param run_name: Defines a name for this run of the experiment, which determines
|
||||||
The name corresponds to the directory (within the logging
|
the subdirectory (within the persistence base directory) where all results will be saved.
|
||||||
directory) where all results associated with the experiment will be saved.
|
If None, the experiment's name will be used.
|
||||||
The name may contain path separators (i.e. `os.path.sep`, as used by `os.path.join`), in which case
|
The name may contain path separators (i.e. `os.path.sep`, as used by `os.path.join`), in which case
|
||||||
a nested directory structure will be created.
|
a nested directory structure will be created.
|
||||||
If "DATETIME_TAG" is passed, use a name containing the current date and time. This option
|
|
||||||
is useful for preventing file-name collisions if a single experiment is executed repeatedly.
|
|
||||||
:param logger_run_id: Run identifier to use for logger initialization/resumption (applies when
|
:param logger_run_id: Run identifier to use for logger initialization/resumption (applies when
|
||||||
using wandb, in particular).
|
using wandb, in particular).
|
||||||
:param raise_error_on_dirname_collision: set to `False` e.g., when continuing a previously executed
|
:param raise_error_on_dirname_collision: set to `False` e.g., when continuing a previously executed
|
||||||
experiment with the same name.
|
experiment with the same name.
|
||||||
|
:param kwargs: for backward compatibility with old parameter names only
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
if override_experiment_name is not None:
|
# backward compatibility
|
||||||
if override_experiment_name == "DATETIME_TAG":
|
_experiment_name = kwargs.pop("experiment_name", None)
|
||||||
override_experiment_name = datetime_tag()
|
if _experiment_name is not None:
|
||||||
self.name = override_experiment_name
|
run_name = _experiment_name
|
||||||
|
deprecation(
|
||||||
|
"Parameter run_name should now be used instead of experiment_name. "
|
||||||
|
"Support for experiment_name will be removed in the future.",
|
||||||
|
)
|
||||||
|
assert len(kwargs) == 0, f"Received unexpected arguments: {kwargs}"
|
||||||
|
|
||||||
|
if run_name is None:
|
||||||
|
run_name = self.name
|
||||||
|
|
||||||
# initialize persistence directory
|
# initialize persistence directory
|
||||||
use_persistence = self.config.persistence_enabled
|
use_persistence = self.config.persistence_enabled
|
||||||
persistence_dir = os.path.join(self.config.persistence_base_dir, self.name)
|
persistence_dir = os.path.join(self.config.persistence_base_dir, run_name)
|
||||||
if use_persistence:
|
if use_persistence:
|
||||||
os.makedirs(persistence_dir, exist_ok=not raise_error_on_dirname_collision)
|
os.makedirs(persistence_dir, exist_ok=not raise_error_on_dirname_collision)
|
||||||
|
|
||||||
@ -240,7 +246,7 @@ class Experiment(ToStringMixin):
|
|||||||
enabled=use_persistence and self.config.log_file_enabled,
|
enabled=use_persistence and self.config.log_file_enabled,
|
||||||
):
|
):
|
||||||
# log initial information
|
# log initial information
|
||||||
log.info(f"Running experiment (name='{self.name}'):\n{self.pprints()}")
|
log.info(f"Running experiment (name='{run_name}'):\n{self.pprints()}")
|
||||||
log.info(f"Working directory: {os.getcwd()}")
|
log.info(f"Working directory: {os.getcwd()}")
|
||||||
|
|
||||||
self._set_seed()
|
self._set_seed()
|
||||||
@ -271,7 +277,7 @@ class Experiment(ToStringMixin):
|
|||||||
if use_persistence:
|
if use_persistence:
|
||||||
logger = self.logger_factory.create_logger(
|
logger = self.logger_factory.create_logger(
|
||||||
log_dir=persistence_dir,
|
log_dir=persistence_dir,
|
||||||
experiment_name=self.name,
|
experiment_name=run_name,
|
||||||
run_id=logger_run_id,
|
run_id=logger_run_id,
|
||||||
config_dict=full_config,
|
config_dict=full_config,
|
||||||
)
|
)
|
||||||
@ -363,7 +369,7 @@ class ExperimentBuilder:
|
|||||||
self._optim_factory: OptimizerFactory | None = None
|
self._optim_factory: OptimizerFactory | None = None
|
||||||
self._policy_wrapper_factory: PolicyWrapperFactory | None = None
|
self._policy_wrapper_factory: PolicyWrapperFactory | None = None
|
||||||
self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks()
|
self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks()
|
||||||
self._experiment_name: str = ""
|
self._name: str = self.__class__.__name__.replace("Builder", "") + "_" + datetime_tag()
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def temp_config_mutation(self) -> Iterator[Self]:
|
def temp_config_mutation(self) -> Iterator[Self]:
|
||||||
@ -466,18 +472,17 @@ class ExperimentBuilder:
|
|||||||
self._trainer_callbacks.epoch_stop_callback = callback
|
self._trainer_callbacks.epoch_stop_callback = callback
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_experiment_name(
|
def with_name(
|
||||||
self,
|
self,
|
||||||
experiment_name: str | Literal["DATETIME_TAG"] = "DATETIME_TAG",
|
name: str,
|
||||||
) -> Self:
|
) -> Self:
|
||||||
"""Sets the name of the experiment.
|
"""Sets the name of the experiment.
|
||||||
|
|
||||||
:param experiment_name: the name. If "DATETIME_TAG" (default) is given, the current date and time will be used.
|
:param name: the name to use for this experiment, which, when the experiment is run,
|
||||||
|
will determine the storage sub-folder by default
|
||||||
:return: the builder
|
:return: the builder
|
||||||
"""
|
"""
|
||||||
if experiment_name == "DATETIME_TAG":
|
self._name = name
|
||||||
experiment_name = datetime_tag()
|
|
||||||
self._experiment_name = experiment_name
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -503,12 +508,12 @@ class ExperimentBuilder:
|
|||||||
if self._policy_wrapper_factory:
|
if self._policy_wrapper_factory:
|
||||||
agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory)
|
agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory)
|
||||||
experiment: Experiment = Experiment(
|
experiment: Experiment = Experiment(
|
||||||
self._config,
|
config=self._config,
|
||||||
self._env_factory,
|
env_factory=self._env_factory,
|
||||||
agent_factory,
|
agent_factory=agent_factory,
|
||||||
self._sampling_config,
|
sampling_config=self._sampling_config,
|
||||||
self._logger_factory,
|
name=self._name,
|
||||||
name=self._experiment_name,
|
logger_factory=self._logger_factory,
|
||||||
)
|
)
|
||||||
if add_seeding_info_to_name:
|
if add_seeding_info_to_name:
|
||||||
if not experiment.name:
|
if not experiment.name:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user