Experiment: use name attribute during run except if overriden explicitly
Pass override_experiment_name as kwarg in examples
This commit is contained in:
parent
ed12b16d70
commit
7d479af0bb
@ -104,7 +104,7 @@ def main(
|
|||||||
)
|
)
|
||||||
|
|
||||||
experiment = builder.build()
|
experiment = builder.build()
|
||||||
experiment.run(log_name)
|
experiment.run(override_experiment_name=log_name)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -96,7 +96,7 @@ def main(
|
|||||||
.with_epoch_stop_callback(AtariEpochStopCallback(task))
|
.with_epoch_stop_callback(AtariEpochStopCallback(task))
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
experiment.run(log_name)
|
experiment.run(override_experiment_name=log_name)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -109,7 +109,7 @@ def main(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
experiment = builder.build()
|
experiment = builder.build()
|
||||||
experiment.run(log_name)
|
experiment.run(override_experiment_name=log_name)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -102,7 +102,7 @@ def main(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
experiment = builder.build()
|
experiment = builder.build()
|
||||||
experiment.run(log_name)
|
experiment.run(override_experiment_name=log_name)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from collections.abc import Sequence
|
||||||
from dataclasses import asdict, dataclass
|
from dataclasses import asdict, dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
@ -36,15 +37,13 @@ class JoblibExpLauncher(ExpLauncher):
|
|||||||
# Joblib's backend is hard-coded to loky since the threading backend produces different results
|
# Joblib's backend is hard-coded to loky since the threading backend produces different results
|
||||||
self.joblib_cfg.backend = "loky"
|
self.joblib_cfg.backend = "loky"
|
||||||
|
|
||||||
def launch(self, experiments: dict[str, Experiment]) -> None:
|
def launch(self, experiments: Sequence[Experiment]) -> None:
|
||||||
Parallel(**asdict(self.joblib_cfg))(
|
Parallel(**asdict(self.joblib_cfg))(delayed(self.execute_task)(exp) for exp in experiments)
|
||||||
delayed(self.execute_task)(exp, exp_name) for exp_name, exp in experiments.items()
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def execute_task(exp: Experiment, name: str):
|
def execute_task(exp: Experiment):
|
||||||
try:
|
try:
|
||||||
exp.run(name)
|
exp.run()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
|
|
||||||
|
@ -83,7 +83,7 @@ def main(
|
|||||||
.with_critic_factory_default(hidden_sizes, nn.Tanh)
|
.with_critic_factory_default(hidden_sizes, nn.Tanh)
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
experiment.run(log_name)
|
experiment.run(override_experiment_name=log_name)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -74,7 +74,7 @@ def main(
|
|||||||
.with_critic_factory_default(hidden_sizes)
|
.with_critic_factory_default(hidden_sizes)
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
experiment.run(log_name)
|
experiment.run(override_experiment_name=log_name)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -85,7 +85,7 @@ def main(
|
|||||||
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
|
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
experiment.run(log_name)
|
experiment.run(override_experiment_name=log_name)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -95,7 +95,7 @@ def main(
|
|||||||
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
|
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
experiment.run(log_name)
|
experiment.run(override_experiment_name=log_name)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -7,7 +7,7 @@ from typing import Literal
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from examples.mujoco.joblib_launcher import JoblibConfig, JoblibLauncher
|
from examples.mujoco.launcher import RegisteredExpLauncher
|
||||||
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
||||||
from tianshou.highlevel.config import SamplingConfig
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
from tianshou.highlevel.env import VectorEnvType
|
from tianshou.highlevel.env import VectorEnvType
|
||||||
@ -113,11 +113,10 @@ def main(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if run_sequential:
|
if run_sequential:
|
||||||
for experiment_name, experiment in experiments.items():
|
launcher = RegisteredExpLauncher.sequential.create_launcher()
|
||||||
experiment.run(experiment_name)
|
|
||||||
else:
|
else:
|
||||||
launcher = JoblibLauncher(JoblibConfig())
|
launcher = RegisteredExpLauncher.joblib.create_launcher()
|
||||||
launcher.launch(experiments)
|
launcher.launch(experiments)
|
||||||
|
|
||||||
return log_name
|
return log_name
|
||||||
|
|
||||||
|
@ -83,7 +83,7 @@ def main(
|
|||||||
.with_critic_ensemble_factory_default(hidden_sizes)
|
.with_critic_ensemble_factory_default(hidden_sizes)
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
experiment.run(log_name)
|
experiment.run(override_experiment_name=log_name)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -72,7 +72,7 @@ def main(
|
|||||||
.with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
|
.with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
experiment.run(log_name)
|
experiment.run(override_experiment_name=log_name)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -80,7 +80,7 @@ def main(
|
|||||||
.with_common_critic_factory_default(hidden_sizes)
|
.with_common_critic_factory_default(hidden_sizes)
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
experiment.run(log_name)
|
experiment.run(override_experiment_name=log_name)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -85,7 +85,7 @@ def main(
|
|||||||
.with_common_critic_factory_default(hidden_sizes, torch.nn.Tanh)
|
.with_common_critic_factory_default(hidden_sizes, torch.nn.Tanh)
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
experiment.run(log_name)
|
experiment.run(override_experiment_name=log_name)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -89,7 +89,7 @@ def main(
|
|||||||
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
|
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
experiment.run(log_name)
|
experiment.run(override_experiment_name=log_name)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -49,7 +49,7 @@ def test_experiment_builder_continuous_default_params(builder_cls: type[Experime
|
|||||||
sampling_config=sampling_config,
|
sampling_config=sampling_config,
|
||||||
)
|
)
|
||||||
experiment = builder.build()
|
experiment = builder.build()
|
||||||
experiment.run("test")
|
experiment.run(override_experiment_name="test")
|
||||||
print(experiment)
|
print(experiment)
|
||||||
|
|
||||||
|
|
||||||
@ -77,7 +77,7 @@ def test_experiment_builder_discrete_default_params(builder_cls: type[Experiment
|
|||||||
sampling_config=sampling_config,
|
sampling_config=sampling_config,
|
||||||
)
|
)
|
||||||
experiment = builder.build()
|
experiment = builder.build()
|
||||||
experiment.run("test")
|
experiment.run(override_experiment_name="test")
|
||||||
print(experiment)
|
print(experiment)
|
||||||
|
|
||||||
|
|
||||||
|
@ -205,35 +205,42 @@ class Experiment(ToStringMixin):
|
|||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
experiment_name: str | None = None,
|
override_experiment_name: str | Literal["DATETIME_TAG"] | None = None,
|
||||||
logger_run_id: str | None = None,
|
logger_run_id: str | None = None,
|
||||||
|
raise_error_on_dirname_collision: bool = True,
|
||||||
) -> ExperimentResult:
|
) -> ExperimentResult:
|
||||||
"""Run the experiment and return the results.
|
"""Run the experiment and return the results.
|
||||||
|
|
||||||
:param experiment_name: the experiment name, which corresponds to the directory (within the logging
|
:param override_experiment_name: if not None, will adjust the current instance's `name` name attribute.
|
||||||
|
The name corresponds to the directory (within the logging
|
||||||
directory) where all results associated with the experiment will be saved.
|
directory) where all results associated with the experiment will be saved.
|
||||||
The name may contain path separators (i.e. `os.path.sep`, as used by `os.path.join`), in which case
|
The name may contain path separators (i.e. `os.path.sep`, as used by `os.path.join`), in which case
|
||||||
a nested directory structure will be created.
|
a nested directory structure will be created.
|
||||||
If None, use a name containing the current date and time.
|
If "DATETIME_TAG" is passed, use a name containing the current date and time. This option
|
||||||
|
is useful for preventing file-name collisions if a single experiment is executed repeatedly.
|
||||||
:param logger_run_id: Run identifier to use for logger initialization/resumption (applies when
|
:param logger_run_id: Run identifier to use for logger initialization/resumption (applies when
|
||||||
using wandb, in particular).
|
using wandb, in particular).
|
||||||
|
:param raise_error_on_dirname_collision: set to `False` e.g., when continuing a previously executed
|
||||||
|
experiment with the same name.
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
if experiment_name is None:
|
if override_experiment_name is not None:
|
||||||
experiment_name = datetime_tag()
|
if override_experiment_name == "DATETIME_TAG":
|
||||||
|
override_experiment_name = datetime_tag()
|
||||||
|
self.name = override_experiment_name
|
||||||
|
|
||||||
# initialize persistence directory
|
# initialize persistence directory
|
||||||
use_persistence = self.config.persistence_enabled
|
use_persistence = self.config.persistence_enabled
|
||||||
persistence_dir = os.path.join(self.config.persistence_base_dir, experiment_name)
|
persistence_dir = os.path.join(self.config.persistence_base_dir, self.name)
|
||||||
if use_persistence:
|
if use_persistence:
|
||||||
os.makedirs(persistence_dir, exist_ok=True)
|
os.makedirs(persistence_dir, exist_ok=not raise_error_on_dirname_collision)
|
||||||
|
|
||||||
with logging.FileLoggerContext(
|
with logging.FileLoggerContext(
|
||||||
os.path.join(persistence_dir, self.LOG_FILENAME),
|
os.path.join(persistence_dir, self.LOG_FILENAME),
|
||||||
enabled=use_persistence and self.config.log_file_enabled,
|
enabled=use_persistence and self.config.log_file_enabled,
|
||||||
):
|
):
|
||||||
# log initial information
|
# log initial information
|
||||||
log.info(f"Running experiment (name='{experiment_name}'):\n{self.pprints()}")
|
log.info(f"Running experiment (name='{self.name}'):\n{self.pprints()}")
|
||||||
log.info(f"Working directory: {os.getcwd()}")
|
log.info(f"Working directory: {os.getcwd()}")
|
||||||
|
|
||||||
self._set_seed()
|
self._set_seed()
|
||||||
@ -264,7 +271,7 @@ class Experiment(ToStringMixin):
|
|||||||
if use_persistence:
|
if use_persistence:
|
||||||
logger = self.logger_factory.create_logger(
|
logger = self.logger_factory.create_logger(
|
||||||
log_dir=persistence_dir,
|
log_dir=persistence_dir,
|
||||||
experiment_name=experiment_name,
|
experiment_name=self.name,
|
||||||
run_id=logger_run_id,
|
run_id=logger_run_id,
|
||||||
config_dict=full_config,
|
config_dict=full_config,
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user