diff --git a/examples/atari/atari_dqn_hl.py b/examples/atari/atari_dqn_hl.py index 890d948..289363e 100644 --- a/examples/atari/atari_dqn_hl.py +++ b/examples/atari/atari_dqn_hl.py @@ -104,7 +104,7 @@ def main( ) experiment = builder.build() - experiment.run(log_name) + experiment.run(override_experiment_name=log_name) if __name__ == "__main__": diff --git a/examples/atari/atari_iqn_hl.py b/examples/atari/atari_iqn_hl.py index 3918550..850c0ff 100644 --- a/examples/atari/atari_iqn_hl.py +++ b/examples/atari/atari_iqn_hl.py @@ -96,7 +96,7 @@ def main( .with_epoch_stop_callback(AtariEpochStopCallback(task)) .build() ) - experiment.run(log_name) + experiment.run(override_experiment_name=log_name) if __name__ == "__main__": diff --git a/examples/atari/atari_ppo_hl.py b/examples/atari/atari_ppo_hl.py index b492b9c..7e8ecf4 100644 --- a/examples/atari/atari_ppo_hl.py +++ b/examples/atari/atari_ppo_hl.py @@ -109,7 +109,7 @@ def main( ), ) experiment = builder.build() - experiment.run(log_name) + experiment.run(override_experiment_name=log_name) if __name__ == "__main__": diff --git a/examples/atari/atari_sac_hl.py b/examples/atari/atari_sac_hl.py index 35f2538..dd798f0 100644 --- a/examples/atari/atari_sac_hl.py +++ b/examples/atari/atari_sac_hl.py @@ -102,7 +102,7 @@ def main( ), ) experiment = builder.build() - experiment.run(log_name) + experiment.run(override_experiment_name=log_name) if __name__ == "__main__": diff --git a/examples/mujoco/launcher.py b/examples/mujoco/launcher.py index 7d680f5..5913fcc 100644 --- a/examples/mujoco/launcher.py +++ b/examples/mujoco/launcher.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from collections.abc import Sequence from dataclasses import asdict, dataclass from enum import Enum from typing import Literal @@ -36,15 +37,13 @@ class JoblibExpLauncher(ExpLauncher): # Joblib's backend is hard-coded to loky since the threading backend produces different results self.joblib_cfg.backend = "loky" - def launch(self, experiments: dict[str, Experiment]) -> None: - Parallel(**asdict(self.joblib_cfg))( - delayed(self.execute_task)(exp, exp_name) for exp_name, exp in experiments.items() - ) + def launch(self, experiments: Sequence[Experiment]) -> None: + Parallel(**asdict(self.joblib_cfg))(delayed(self.execute_task)(exp) for exp in experiments) @staticmethod - def execute_task(exp: Experiment, name: str): + def execute_task(exp: Experiment): try: - exp.run(name) + exp.run() except Exception as e: print(e) diff --git a/examples/mujoco/mujoco_a2c_hl.py b/examples/mujoco/mujoco_a2c_hl.py index ef6ebd1..17d5de5 100644 --- a/examples/mujoco/mujoco_a2c_hl.py +++ b/examples/mujoco/mujoco_a2c_hl.py @@ -83,7 +83,7 @@ def main( .with_critic_factory_default(hidden_sizes, nn.Tanh) .build() ) - experiment.run(log_name) + experiment.run(override_experiment_name=log_name) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_ddpg_hl.py b/examples/mujoco/mujoco_ddpg_hl.py index 6ef3886..a476245 100644 --- a/examples/mujoco/mujoco_ddpg_hl.py +++ b/examples/mujoco/mujoco_ddpg_hl.py @@ -74,7 +74,7 @@ def main( .with_critic_factory_default(hidden_sizes) .build() ) - experiment.run(log_name) + experiment.run(override_experiment_name=log_name) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_npg_hl.py b/examples/mujoco/mujoco_npg_hl.py index c430d86..1b57210 100644 --- a/examples/mujoco/mujoco_npg_hl.py +++ b/examples/mujoco/mujoco_npg_hl.py @@ -85,7 +85,7 @@ def main( .with_critic_factory_default(hidden_sizes, torch.nn.Tanh) .build() ) - experiment.run(log_name) + experiment.run(override_experiment_name=log_name) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_ppo_hl.py b/examples/mujoco/mujoco_ppo_hl.py index ccd4b59..601b084 100644 --- a/examples/mujoco/mujoco_ppo_hl.py +++ b/examples/mujoco/mujoco_ppo_hl.py @@ -95,7 +95,7 @@ def main( .with_critic_factory_default(hidden_sizes, torch.nn.Tanh) .build() ) - experiment.run(log_name) + experiment.run(override_experiment_name=log_name) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_ppo_hl_multi.py b/examples/mujoco/mujoco_ppo_hl_multi.py index ea90bf5..0bae55c 100644 --- a/examples/mujoco/mujoco_ppo_hl_multi.py +++ b/examples/mujoco/mujoco_ppo_hl_multi.py @@ -7,7 +7,7 @@ from typing import Literal import torch -from examples.mujoco.joblib_launcher import JoblibConfig, JoblibLauncher +from examples.mujoco.launcher import RegisteredExpLauncher from examples.mujoco.mujoco_env import MujocoEnvFactory from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.env import VectorEnvType @@ -113,11 +113,10 @@ def main( ) if run_sequential: - for experiment_name, experiment in experiments.items(): - experiment.run(experiment_name) + launcher = RegisteredExpLauncher.sequential.create_launcher() else: - launcher = JoblibLauncher(JoblibConfig()) - launcher.launch(experiments) + launcher = RegisteredExpLauncher.joblib.create_launcher() + launcher.launch(experiments) return log_name diff --git a/examples/mujoco/mujoco_redq_hl.py b/examples/mujoco/mujoco_redq_hl.py index 8bff905..9b4bca7 100644 --- a/examples/mujoco/mujoco_redq_hl.py +++ b/examples/mujoco/mujoco_redq_hl.py @@ -83,7 +83,7 @@ def main( .with_critic_ensemble_factory_default(hidden_sizes) .build() ) - experiment.run(log_name) + experiment.run(override_experiment_name=log_name) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_reinforce_hl.py b/examples/mujoco/mujoco_reinforce_hl.py index 1de7dec..a5ec65f 100644 --- a/examples/mujoco/mujoco_reinforce_hl.py +++ b/examples/mujoco/mujoco_reinforce_hl.py @@ -72,7 +72,7 @@ def main( .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True) .build() ) - experiment.run(log_name) + experiment.run(override_experiment_name=log_name) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_sac_hl.py b/examples/mujoco/mujoco_sac_hl.py index 759e615..9ffa0f4 100644 --- a/examples/mujoco/mujoco_sac_hl.py +++ b/examples/mujoco/mujoco_sac_hl.py @@ -80,7 +80,7 @@ def main( .with_common_critic_factory_default(hidden_sizes) .build() ) - experiment.run(log_name) + experiment.run(override_experiment_name=log_name) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_td3_hl.py b/examples/mujoco/mujoco_td3_hl.py index 326223c..6adc73d 100644 --- a/examples/mujoco/mujoco_td3_hl.py +++ b/examples/mujoco/mujoco_td3_hl.py @@ -85,7 +85,7 @@ def main( .with_common_critic_factory_default(hidden_sizes, torch.nn.Tanh) .build() ) - experiment.run(log_name) + experiment.run(override_experiment_name=log_name) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_trpo_hl.py b/examples/mujoco/mujoco_trpo_hl.py index 6231802..e6a411d 100644 --- a/examples/mujoco/mujoco_trpo_hl.py +++ b/examples/mujoco/mujoco_trpo_hl.py @@ -89,7 +89,7 @@ def main( .with_critic_factory_default(hidden_sizes, torch.nn.Tanh) .build() ) - experiment.run(log_name) + experiment.run(override_experiment_name=log_name) if __name__ == "__main__": diff --git a/test/highlevel/test_experiment_builder.py b/test/highlevel/test_experiment_builder.py index 21c79f5..0df32b0 100644 --- a/test/highlevel/test_experiment_builder.py +++ b/test/highlevel/test_experiment_builder.py @@ -49,7 +49,7 @@ def test_experiment_builder_continuous_default_params(builder_cls: type[Experime sampling_config=sampling_config, ) experiment = builder.build() - experiment.run("test") + experiment.run(override_experiment_name="test") print(experiment) @@ -77,7 +77,7 @@ def test_experiment_builder_discrete_default_params(builder_cls: type[Experiment sampling_config=sampling_config, ) experiment = builder.build() - experiment.run("test") + experiment.run(override_experiment_name="test") print(experiment) diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index 431bccb..d36d402 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -205,35 +205,42 @@ class Experiment(ToStringMixin): def run( self, - experiment_name: str | None = None, + override_experiment_name: str | Literal["DATETIME_TAG"] | None = None, logger_run_id: str | None = None, + raise_error_on_dirname_collision: bool = True, ) -> ExperimentResult: """Run the experiment and return the results. - :param experiment_name: the experiment name, which corresponds to the directory (within the logging + :param override_experiment_name: if not None, will adjust the current instance's `name` name attribute. + The name corresponds to the directory (within the logging directory) where all results associated with the experiment will be saved. The name may contain path separators (i.e. `os.path.sep`, as used by `os.path.join`), in which case a nested directory structure will be created. - If None, use a name containing the current date and time. + If "DATETIME_TAG" is passed, use a name containing the current date and time. This option + is useful for preventing file-name collisions if a single experiment is executed repeatedly. :param logger_run_id: Run identifier to use for logger initialization/resumption (applies when using wandb, in particular). + :param raise_error_on_dirname_collision: set to `False` e.g., when continuing a previously executed + experiment with the same name. :return: """ - if experiment_name is None: - experiment_name = datetime_tag() + if override_experiment_name is not None: + if override_experiment_name == "DATETIME_TAG": + override_experiment_name = datetime_tag() + self.name = override_experiment_name # initialize persistence directory use_persistence = self.config.persistence_enabled - persistence_dir = os.path.join(self.config.persistence_base_dir, experiment_name) + persistence_dir = os.path.join(self.config.persistence_base_dir, self.name) if use_persistence: - os.makedirs(persistence_dir, exist_ok=True) + os.makedirs(persistence_dir, exist_ok=not raise_error_on_dirname_collision) with logging.FileLoggerContext( os.path.join(persistence_dir, self.LOG_FILENAME), enabled=use_persistence and self.config.log_file_enabled, ): # log initial information - log.info(f"Running experiment (name='{experiment_name}'):\n{self.pprints()}") + log.info(f"Running experiment (name='{self.name}'):\n{self.pprints()}") log.info(f"Working directory: {os.getcwd()}") self._set_seed() @@ -264,7 +271,7 @@ class Experiment(ToStringMixin): if use_persistence: logger = self.logger_factory.create_logger( log_dir=persistence_dir, - experiment_name=experiment_name, + experiment_name=self.name, run_id=logger_run_id, config_dict=full_config, )