Experiment: use name attribute during run except if overriden explicitly

Pass override_experiment_name as kwarg in examples
This commit is contained in:
Michael Panchenko 2024-04-03 17:44:41 +02:00
parent ed12b16d70
commit 7d479af0bb
17 changed files with 40 additions and 35 deletions

View File

@ -104,7 +104,7 @@ def main(
)
experiment = builder.build()
experiment.run(log_name)
experiment.run(override_experiment_name=log_name)
if __name__ == "__main__":

View File

@ -96,7 +96,7 @@ def main(
.with_epoch_stop_callback(AtariEpochStopCallback(task))
.build()
)
experiment.run(log_name)
experiment.run(override_experiment_name=log_name)
if __name__ == "__main__":

View File

@ -109,7 +109,7 @@ def main(
),
)
experiment = builder.build()
experiment.run(log_name)
experiment.run(override_experiment_name=log_name)
if __name__ == "__main__":

View File

@ -102,7 +102,7 @@ def main(
),
)
experiment = builder.build()
experiment.run(log_name)
experiment.run(override_experiment_name=log_name)
if __name__ == "__main__":

View File

@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import asdict, dataclass
from enum import Enum
from typing import Literal
@ -36,15 +37,13 @@ class JoblibExpLauncher(ExpLauncher):
# Joblib's backend is hard-coded to loky since the threading backend produces different results
self.joblib_cfg.backend = "loky"
def launch(self, experiments: dict[str, Experiment]) -> None:
Parallel(**asdict(self.joblib_cfg))(
delayed(self.execute_task)(exp, exp_name) for exp_name, exp in experiments.items()
)
def launch(self, experiments: Sequence[Experiment]) -> None:
Parallel(**asdict(self.joblib_cfg))(delayed(self.execute_task)(exp) for exp in experiments)
@staticmethod
def execute_task(exp: Experiment, name: str):
def execute_task(exp: Experiment):
try:
exp.run(name)
exp.run()
except Exception as e:
print(e)

View File

@ -83,7 +83,7 @@ def main(
.with_critic_factory_default(hidden_sizes, nn.Tanh)
.build()
)
experiment.run(log_name)
experiment.run(override_experiment_name=log_name)
if __name__ == "__main__":

View File

@ -74,7 +74,7 @@ def main(
.with_critic_factory_default(hidden_sizes)
.build()
)
experiment.run(log_name)
experiment.run(override_experiment_name=log_name)
if __name__ == "__main__":

View File

@ -85,7 +85,7 @@ def main(
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
.build()
)
experiment.run(log_name)
experiment.run(override_experiment_name=log_name)
if __name__ == "__main__":

View File

@ -95,7 +95,7 @@ def main(
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
.build()
)
experiment.run(log_name)
experiment.run(override_experiment_name=log_name)
if __name__ == "__main__":

View File

@ -7,7 +7,7 @@ from typing import Literal
import torch
from examples.mujoco.joblib_launcher import JoblibConfig, JoblibLauncher
from examples.mujoco.launcher import RegisteredExpLauncher
from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import VectorEnvType
@ -113,11 +113,10 @@ def main(
)
if run_sequential:
for experiment_name, experiment in experiments.items():
experiment.run(experiment_name)
launcher = RegisteredExpLauncher.sequential.create_launcher()
else:
launcher = JoblibLauncher(JoblibConfig())
launcher.launch(experiments)
launcher = RegisteredExpLauncher.joblib.create_launcher()
launcher.launch(experiments)
return log_name

View File

@ -83,7 +83,7 @@ def main(
.with_critic_ensemble_factory_default(hidden_sizes)
.build()
)
experiment.run(log_name)
experiment.run(override_experiment_name=log_name)
if __name__ == "__main__":

View File

@ -72,7 +72,7 @@ def main(
.with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
.build()
)
experiment.run(log_name)
experiment.run(override_experiment_name=log_name)
if __name__ == "__main__":

View File

@ -80,7 +80,7 @@ def main(
.with_common_critic_factory_default(hidden_sizes)
.build()
)
experiment.run(log_name)
experiment.run(override_experiment_name=log_name)
if __name__ == "__main__":

View File

@ -85,7 +85,7 @@ def main(
.with_common_critic_factory_default(hidden_sizes, torch.nn.Tanh)
.build()
)
experiment.run(log_name)
experiment.run(override_experiment_name=log_name)
if __name__ == "__main__":

View File

@ -89,7 +89,7 @@ def main(
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
.build()
)
experiment.run(log_name)
experiment.run(override_experiment_name=log_name)
if __name__ == "__main__":

View File

@ -49,7 +49,7 @@ def test_experiment_builder_continuous_default_params(builder_cls: type[Experime
sampling_config=sampling_config,
)
experiment = builder.build()
experiment.run("test")
experiment.run(override_experiment_name="test")
print(experiment)
@ -77,7 +77,7 @@ def test_experiment_builder_discrete_default_params(builder_cls: type[Experiment
sampling_config=sampling_config,
)
experiment = builder.build()
experiment.run("test")
experiment.run(override_experiment_name="test")
print(experiment)

View File

@ -205,35 +205,42 @@ class Experiment(ToStringMixin):
def run(
self,
experiment_name: str | None = None,
override_experiment_name: str | Literal["DATETIME_TAG"] | None = None,
logger_run_id: str | None = None,
raise_error_on_dirname_collision: bool = True,
) -> ExperimentResult:
"""Run the experiment and return the results.
:param experiment_name: the experiment name, which corresponds to the directory (within the logging
:param override_experiment_name: if not None, will adjust the current instance's `name` name attribute.
The name corresponds to the directory (within the logging
directory) where all results associated with the experiment will be saved.
The name may contain path separators (i.e. `os.path.sep`, as used by `os.path.join`), in which case
a nested directory structure will be created.
If None, use a name containing the current date and time.
If "DATETIME_TAG" is passed, use a name containing the current date and time. This option
is useful for preventing file-name collisions if a single experiment is executed repeatedly.
:param logger_run_id: Run identifier to use for logger initialization/resumption (applies when
using wandb, in particular).
:param raise_error_on_dirname_collision: set to `False` e.g., when continuing a previously executed
experiment with the same name.
:return:
"""
if experiment_name is None:
experiment_name = datetime_tag()
if override_experiment_name is not None:
if override_experiment_name == "DATETIME_TAG":
override_experiment_name = datetime_tag()
self.name = override_experiment_name
# initialize persistence directory
use_persistence = self.config.persistence_enabled
persistence_dir = os.path.join(self.config.persistence_base_dir, experiment_name)
persistence_dir = os.path.join(self.config.persistence_base_dir, self.name)
if use_persistence:
os.makedirs(persistence_dir, exist_ok=True)
os.makedirs(persistence_dir, exist_ok=not raise_error_on_dirname_collision)
with logging.FileLoggerContext(
os.path.join(persistence_dir, self.LOG_FILENAME),
enabled=use_persistence and self.config.log_file_enabled,
):
# log initial information
log.info(f"Running experiment (name='{experiment_name}'):\n{self.pprints()}")
log.info(f"Running experiment (name='{self.name}'):\n{self.pprints()}")
log.info(f"Working directory: {os.getcwd()}")
self._set_seed()
@ -264,7 +271,7 @@ class Experiment(ToStringMixin):
if use_persistence:
logger = self.logger_factory.create_logger(
log_dir=persistence_dir,
experiment_name=experiment_name,
experiment_name=self.name,
run_id=logger_run_id,
config_dict=full_config,
)