Experiment: use name attribute during run except if overriden explicitly
Pass override_experiment_name as kwarg in examples
This commit is contained in:
parent
ed12b16d70
commit
7d479af0bb
@ -104,7 +104,7 @@ def main(
|
||||
)
|
||||
|
||||
experiment = builder.build()
|
||||
experiment.run(log_name)
|
||||
experiment.run(override_experiment_name=log_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -96,7 +96,7 @@ def main(
|
||||
.with_epoch_stop_callback(AtariEpochStopCallback(task))
|
||||
.build()
|
||||
)
|
||||
experiment.run(log_name)
|
||||
experiment.run(override_experiment_name=log_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -109,7 +109,7 @@ def main(
|
||||
),
|
||||
)
|
||||
experiment = builder.build()
|
||||
experiment.run(log_name)
|
||||
experiment.run(override_experiment_name=log_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -102,7 +102,7 @@ def main(
|
||||
),
|
||||
)
|
||||
experiment = builder.build()
|
||||
experiment.run(log_name)
|
||||
experiment.run(override_experiment_name=log_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,4 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import asdict, dataclass
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
@ -36,15 +37,13 @@ class JoblibExpLauncher(ExpLauncher):
|
||||
# Joblib's backend is hard-coded to loky since the threading backend produces different results
|
||||
self.joblib_cfg.backend = "loky"
|
||||
|
||||
def launch(self, experiments: dict[str, Experiment]) -> None:
|
||||
Parallel(**asdict(self.joblib_cfg))(
|
||||
delayed(self.execute_task)(exp, exp_name) for exp_name, exp in experiments.items()
|
||||
)
|
||||
def launch(self, experiments: Sequence[Experiment]) -> None:
|
||||
Parallel(**asdict(self.joblib_cfg))(delayed(self.execute_task)(exp) for exp in experiments)
|
||||
|
||||
@staticmethod
|
||||
def execute_task(exp: Experiment, name: str):
|
||||
def execute_task(exp: Experiment):
|
||||
try:
|
||||
exp.run(name)
|
||||
exp.run()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
|
@ -83,7 +83,7 @@ def main(
|
||||
.with_critic_factory_default(hidden_sizes, nn.Tanh)
|
||||
.build()
|
||||
)
|
||||
experiment.run(log_name)
|
||||
experiment.run(override_experiment_name=log_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -74,7 +74,7 @@ def main(
|
||||
.with_critic_factory_default(hidden_sizes)
|
||||
.build()
|
||||
)
|
||||
experiment.run(log_name)
|
||||
experiment.run(override_experiment_name=log_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -85,7 +85,7 @@ def main(
|
||||
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
|
||||
.build()
|
||||
)
|
||||
experiment.run(log_name)
|
||||
experiment.run(override_experiment_name=log_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -95,7 +95,7 @@ def main(
|
||||
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
|
||||
.build()
|
||||
)
|
||||
experiment.run(log_name)
|
||||
experiment.run(override_experiment_name=log_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -7,7 +7,7 @@ from typing import Literal
|
||||
|
||||
import torch
|
||||
|
||||
from examples.mujoco.joblib_launcher import JoblibConfig, JoblibLauncher
|
||||
from examples.mujoco.launcher import RegisteredExpLauncher
|
||||
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
||||
from tianshou.highlevel.config import SamplingConfig
|
||||
from tianshou.highlevel.env import VectorEnvType
|
||||
@ -113,11 +113,10 @@ def main(
|
||||
)
|
||||
|
||||
if run_sequential:
|
||||
for experiment_name, experiment in experiments.items():
|
||||
experiment.run(experiment_name)
|
||||
launcher = RegisteredExpLauncher.sequential.create_launcher()
|
||||
else:
|
||||
launcher = JoblibLauncher(JoblibConfig())
|
||||
launcher.launch(experiments)
|
||||
launcher = RegisteredExpLauncher.joblib.create_launcher()
|
||||
launcher.launch(experiments)
|
||||
|
||||
return log_name
|
||||
|
||||
|
@ -83,7 +83,7 @@ def main(
|
||||
.with_critic_ensemble_factory_default(hidden_sizes)
|
||||
.build()
|
||||
)
|
||||
experiment.run(log_name)
|
||||
experiment.run(override_experiment_name=log_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -72,7 +72,7 @@ def main(
|
||||
.with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
|
||||
.build()
|
||||
)
|
||||
experiment.run(log_name)
|
||||
experiment.run(override_experiment_name=log_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -80,7 +80,7 @@ def main(
|
||||
.with_common_critic_factory_default(hidden_sizes)
|
||||
.build()
|
||||
)
|
||||
experiment.run(log_name)
|
||||
experiment.run(override_experiment_name=log_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -85,7 +85,7 @@ def main(
|
||||
.with_common_critic_factory_default(hidden_sizes, torch.nn.Tanh)
|
||||
.build()
|
||||
)
|
||||
experiment.run(log_name)
|
||||
experiment.run(override_experiment_name=log_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -89,7 +89,7 @@ def main(
|
||||
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
|
||||
.build()
|
||||
)
|
||||
experiment.run(log_name)
|
||||
experiment.run(override_experiment_name=log_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -49,7 +49,7 @@ def test_experiment_builder_continuous_default_params(builder_cls: type[Experime
|
||||
sampling_config=sampling_config,
|
||||
)
|
||||
experiment = builder.build()
|
||||
experiment.run("test")
|
||||
experiment.run(override_experiment_name="test")
|
||||
print(experiment)
|
||||
|
||||
|
||||
@ -77,7 +77,7 @@ def test_experiment_builder_discrete_default_params(builder_cls: type[Experiment
|
||||
sampling_config=sampling_config,
|
||||
)
|
||||
experiment = builder.build()
|
||||
experiment.run("test")
|
||||
experiment.run(override_experiment_name="test")
|
||||
print(experiment)
|
||||
|
||||
|
||||
|
@ -205,35 +205,42 @@ class Experiment(ToStringMixin):
|
||||
|
||||
def run(
|
||||
self,
|
||||
experiment_name: str | None = None,
|
||||
override_experiment_name: str | Literal["DATETIME_TAG"] | None = None,
|
||||
logger_run_id: str | None = None,
|
||||
raise_error_on_dirname_collision: bool = True,
|
||||
) -> ExperimentResult:
|
||||
"""Run the experiment and return the results.
|
||||
|
||||
:param experiment_name: the experiment name, which corresponds to the directory (within the logging
|
||||
:param override_experiment_name: if not None, will adjust the current instance's `name` name attribute.
|
||||
The name corresponds to the directory (within the logging
|
||||
directory) where all results associated with the experiment will be saved.
|
||||
The name may contain path separators (i.e. `os.path.sep`, as used by `os.path.join`), in which case
|
||||
a nested directory structure will be created.
|
||||
If None, use a name containing the current date and time.
|
||||
If "DATETIME_TAG" is passed, use a name containing the current date and time. This option
|
||||
is useful for preventing file-name collisions if a single experiment is executed repeatedly.
|
||||
:param logger_run_id: Run identifier to use for logger initialization/resumption (applies when
|
||||
using wandb, in particular).
|
||||
:param raise_error_on_dirname_collision: set to `False` e.g., when continuing a previously executed
|
||||
experiment with the same name.
|
||||
:return:
|
||||
"""
|
||||
if experiment_name is None:
|
||||
experiment_name = datetime_tag()
|
||||
if override_experiment_name is not None:
|
||||
if override_experiment_name == "DATETIME_TAG":
|
||||
override_experiment_name = datetime_tag()
|
||||
self.name = override_experiment_name
|
||||
|
||||
# initialize persistence directory
|
||||
use_persistence = self.config.persistence_enabled
|
||||
persistence_dir = os.path.join(self.config.persistence_base_dir, experiment_name)
|
||||
persistence_dir = os.path.join(self.config.persistence_base_dir, self.name)
|
||||
if use_persistence:
|
||||
os.makedirs(persistence_dir, exist_ok=True)
|
||||
os.makedirs(persistence_dir, exist_ok=not raise_error_on_dirname_collision)
|
||||
|
||||
with logging.FileLoggerContext(
|
||||
os.path.join(persistence_dir, self.LOG_FILENAME),
|
||||
enabled=use_persistence and self.config.log_file_enabled,
|
||||
):
|
||||
# log initial information
|
||||
log.info(f"Running experiment (name='{experiment_name}'):\n{self.pprints()}")
|
||||
log.info(f"Running experiment (name='{self.name}'):\n{self.pprints()}")
|
||||
log.info(f"Working directory: {os.getcwd()}")
|
||||
|
||||
self._set_seed()
|
||||
@ -264,7 +271,7 @@ class Experiment(ToStringMixin):
|
||||
if use_persistence:
|
||||
logger = self.logger_factory.create_logger(
|
||||
log_dir=persistence_dir,
|
||||
experiment_name=experiment_name,
|
||||
experiment_name=self.name,
|
||||
run_id=logger_run_id,
|
||||
config_dict=full_config,
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user