Experiment: use name attribute during run except if overriden explicitly

Pass override_experiment_name as kwarg in examples
This commit is contained in:
Michael Panchenko 2024-04-03 17:44:41 +02:00
parent ed12b16d70
commit 7d479af0bb
17 changed files with 40 additions and 35 deletions

View File

@ -104,7 +104,7 @@ def main(
) )
experiment = builder.build() experiment = builder.build()
experiment.run(log_name) experiment.run(override_experiment_name=log_name)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -96,7 +96,7 @@ def main(
.with_epoch_stop_callback(AtariEpochStopCallback(task)) .with_epoch_stop_callback(AtariEpochStopCallback(task))
.build() .build()
) )
experiment.run(log_name) experiment.run(override_experiment_name=log_name)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -109,7 +109,7 @@ def main(
), ),
) )
experiment = builder.build() experiment = builder.build()
experiment.run(log_name) experiment.run(override_experiment_name=log_name)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -102,7 +102,7 @@ def main(
), ),
) )
experiment = builder.build() experiment = builder.build()
experiment.run(log_name) experiment.run(override_experiment_name=log_name)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,4 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from enum import Enum from enum import Enum
from typing import Literal from typing import Literal
@ -36,15 +37,13 @@ class JoblibExpLauncher(ExpLauncher):
# Joblib's backend is hard-coded to loky since the threading backend produces different results # Joblib's backend is hard-coded to loky since the threading backend produces different results
self.joblib_cfg.backend = "loky" self.joblib_cfg.backend = "loky"
def launch(self, experiments: dict[str, Experiment]) -> None: def launch(self, experiments: Sequence[Experiment]) -> None:
Parallel(**asdict(self.joblib_cfg))( Parallel(**asdict(self.joblib_cfg))(delayed(self.execute_task)(exp) for exp in experiments)
delayed(self.execute_task)(exp, exp_name) for exp_name, exp in experiments.items()
)
@staticmethod @staticmethod
def execute_task(exp: Experiment, name: str): def execute_task(exp: Experiment):
try: try:
exp.run(name) exp.run()
except Exception as e: except Exception as e:
print(e) print(e)

View File

@ -83,7 +83,7 @@ def main(
.with_critic_factory_default(hidden_sizes, nn.Tanh) .with_critic_factory_default(hidden_sizes, nn.Tanh)
.build() .build()
) )
experiment.run(log_name) experiment.run(override_experiment_name=log_name)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -74,7 +74,7 @@ def main(
.with_critic_factory_default(hidden_sizes) .with_critic_factory_default(hidden_sizes)
.build() .build()
) )
experiment.run(log_name) experiment.run(override_experiment_name=log_name)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -85,7 +85,7 @@ def main(
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh) .with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
.build() .build()
) )
experiment.run(log_name) experiment.run(override_experiment_name=log_name)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -95,7 +95,7 @@ def main(
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh) .with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
.build() .build()
) )
experiment.run(log_name) experiment.run(override_experiment_name=log_name)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -7,7 +7,7 @@ from typing import Literal
import torch import torch
from examples.mujoco.joblib_launcher import JoblibConfig, JoblibLauncher from examples.mujoco.launcher import RegisteredExpLauncher
from examples.mujoco.mujoco_env import MujocoEnvFactory from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import VectorEnvType from tianshou.highlevel.env import VectorEnvType
@ -113,11 +113,10 @@ def main(
) )
if run_sequential: if run_sequential:
for experiment_name, experiment in experiments.items(): launcher = RegisteredExpLauncher.sequential.create_launcher()
experiment.run(experiment_name)
else: else:
launcher = JoblibLauncher(JoblibConfig()) launcher = RegisteredExpLauncher.joblib.create_launcher()
launcher.launch(experiments) launcher.launch(experiments)
return log_name return log_name

View File

@ -83,7 +83,7 @@ def main(
.with_critic_ensemble_factory_default(hidden_sizes) .with_critic_ensemble_factory_default(hidden_sizes)
.build() .build()
) )
experiment.run(log_name) experiment.run(override_experiment_name=log_name)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -72,7 +72,7 @@ def main(
.with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True) .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
.build() .build()
) )
experiment.run(log_name) experiment.run(override_experiment_name=log_name)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -80,7 +80,7 @@ def main(
.with_common_critic_factory_default(hidden_sizes) .with_common_critic_factory_default(hidden_sizes)
.build() .build()
) )
experiment.run(log_name) experiment.run(override_experiment_name=log_name)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -85,7 +85,7 @@ def main(
.with_common_critic_factory_default(hidden_sizes, torch.nn.Tanh) .with_common_critic_factory_default(hidden_sizes, torch.nn.Tanh)
.build() .build()
) )
experiment.run(log_name) experiment.run(override_experiment_name=log_name)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -89,7 +89,7 @@ def main(
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh) .with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
.build() .build()
) )
experiment.run(log_name) experiment.run(override_experiment_name=log_name)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -49,7 +49,7 @@ def test_experiment_builder_continuous_default_params(builder_cls: type[Experime
sampling_config=sampling_config, sampling_config=sampling_config,
) )
experiment = builder.build() experiment = builder.build()
experiment.run("test") experiment.run(override_experiment_name="test")
print(experiment) print(experiment)
@ -77,7 +77,7 @@ def test_experiment_builder_discrete_default_params(builder_cls: type[Experiment
sampling_config=sampling_config, sampling_config=sampling_config,
) )
experiment = builder.build() experiment = builder.build()
experiment.run("test") experiment.run(override_experiment_name="test")
print(experiment) print(experiment)

View File

@ -205,35 +205,42 @@ class Experiment(ToStringMixin):
def run( def run(
self, self,
experiment_name: str | None = None, override_experiment_name: str | Literal["DATETIME_TAG"] | None = None,
logger_run_id: str | None = None, logger_run_id: str | None = None,
raise_error_on_dirname_collision: bool = True,
) -> ExperimentResult: ) -> ExperimentResult:
"""Run the experiment and return the results. """Run the experiment and return the results.
:param experiment_name: the experiment name, which corresponds to the directory (within the logging :param override_experiment_name: if not None, will adjust the current instance's `name` name attribute.
The name corresponds to the directory (within the logging
directory) where all results associated with the experiment will be saved. directory) where all results associated with the experiment will be saved.
The name may contain path separators (i.e. `os.path.sep`, as used by `os.path.join`), in which case The name may contain path separators (i.e. `os.path.sep`, as used by `os.path.join`), in which case
a nested directory structure will be created. a nested directory structure will be created.
If None, use a name containing the current date and time. If "DATETIME_TAG" is passed, use a name containing the current date and time. This option
is useful for preventing file-name collisions if a single experiment is executed repeatedly.
:param logger_run_id: Run identifier to use for logger initialization/resumption (applies when :param logger_run_id: Run identifier to use for logger initialization/resumption (applies when
using wandb, in particular). using wandb, in particular).
:param raise_error_on_dirname_collision: set to `False` e.g., when continuing a previously executed
experiment with the same name.
:return: :return:
""" """
if experiment_name is None: if override_experiment_name is not None:
experiment_name = datetime_tag() if override_experiment_name == "DATETIME_TAG":
override_experiment_name = datetime_tag()
self.name = override_experiment_name
# initialize persistence directory # initialize persistence directory
use_persistence = self.config.persistence_enabled use_persistence = self.config.persistence_enabled
persistence_dir = os.path.join(self.config.persistence_base_dir, experiment_name) persistence_dir = os.path.join(self.config.persistence_base_dir, self.name)
if use_persistence: if use_persistence:
os.makedirs(persistence_dir, exist_ok=True) os.makedirs(persistence_dir, exist_ok=not raise_error_on_dirname_collision)
with logging.FileLoggerContext( with logging.FileLoggerContext(
os.path.join(persistence_dir, self.LOG_FILENAME), os.path.join(persistence_dir, self.LOG_FILENAME),
enabled=use_persistence and self.config.log_file_enabled, enabled=use_persistence and self.config.log_file_enabled,
): ):
# log initial information # log initial information
log.info(f"Running experiment (name='{experiment_name}'):\n{self.pprints()}") log.info(f"Running experiment (name='{self.name}'):\n{self.pprints()}")
log.info(f"Working directory: {os.getcwd()}") log.info(f"Working directory: {os.getcwd()}")
self._set_seed() self._set_seed()
@ -264,7 +271,7 @@ class Experiment(ToStringMixin):
if use_persistence: if use_persistence:
logger = self.logger_factory.create_logger( logger = self.logger_factory.create_logger(
log_dir=persistence_dir, log_dir=persistence_dir,
experiment_name=experiment_name, experiment_name=self.name,
run_id=logger_run_id, run_id=logger_run_id,
config_dict=full_config, config_dict=full_config,
) )