Improvements pertaining to the handling of multi-experiment creation (#1131)

Description of changes: see individual commits; merged without squashing.

Co-authored by: @maxhuettenrauch 

Partly addressed #1129
Michael Panchenko 2024-05-05 21:41:53 +02:00 committed by GitHub
commit 9fbf28ef6e
17 changed files with 165 additions and 233 deletions

View File

@ -3,23 +3,38 @@
## Release 1.1.0 ## Release 1.1.0
### Api Extensions ### Api Extensions
- Batch received two new methods: `to_dict` and `to_list_of_dicts`. #1063 - `data`:
- `Collector`s can now be closed, and their reset is more granular. #1063 - `Batch`:
- Trainers can control whether collectors should be reset prior to training. #1063 - Add methods `to_dict` and `to_list_of_dicts`. #1063 #1098
- Convenience constructor for `CollectStats` called `with_autogenerated_stats`. #1063 - Add methods `to_numpy_` and `to_torch_`. #1098, #1117
- `SamplingConfig` supports `batch_size=None`. #1077 - Add `__eq__` (semantic equality check). #1098
- Batch received new methods: `to_numpy_` and `to_torch_`. #1098, #1117 - `data.collector`:
- `to_dict` in Batch supports also non-recursive conversion. #1098 - `Collector`:
- Batch `__eq__` implemented, semantic equality check of batches is now possible. #1098 - Add method `close` #1063
- Method `reset` is now more granular (new flags controlling behavior). #1063
- `CollectStats`: Add convenience constructor `with_autogenerated_stats`. #1063
- `trainer`:
- Trainers can now control whether collectors should be reset prior to training. #1063
- `Batch.keys()` deprecated in favor of `Batch.get_keys()` (needed to make iteration consistent with naming) #1105. - `Batch.keys()` deprecated in favor of `Batch.get_keys()` (needed to make iteration consistent with naming) #1105.
- `Experiment` and `ExperimentConfig` now have a `name`, that can however be overridden when `Experiment.run()` is called. #1074 - `highlevel`:
- When building an `Experiment` from an `ExperimentConfig`, the user has the option to add info about seeds to the name. #1074 - `SamplingConfig`:
- New method in `ExperimentConfig` called `build_default_seeded_experiments`. #1074 - Add support for `batch_size=None`. #1077
- `SamplingConfig` has an explicit training seed, `test_seed` is inferred. #1074 - Add `train_seed` for explicit seeding of training and test environments; the `test_seed` is inferred from `train_seed`. #1074
- New `evaluation` package for repeating the same experiment with multiple seeds and aggregating the results (important extension!). - `highlevel.experiment`:
Launchers for parallelization currently in alpha state. #1074 - `Experiment` now has a `name` attribute, which can be set using `ExperimentBuilder.with_name` and
which determines the default run name and therefore the persistence subdirectory.
It can still be overridden in `Experiment.run()`, the new parameter name being `run_name` rather than
`experiment_name` (although the latter will still be interpreted correctly). #1074 #1131
- Add class `ExperimentCollection` for the convenient execution of multiple experiment runs #1131
- `ExperimentBuilder`:
- Add method `build_seeded_collection` for the sound creation of multiple
experiments with varying random seeds #1131
- Add method `copy` to facilitate the creation of multiple experiments from a single builder #1131
- `evaluation`: New package for repeating the same experiment with multiple seeds and aggregating the results. #1074
- The module `evaluation.launchers` for parallelization is currently in alpha state.
- Loggers can now restore the logged data into python by using the new `restore_logged_data` method. #1074 - Loggers can now restore the logged data into python by using the new `restore_logged_data` method. #1074
- `continuous.Critic`: - `utils.net`:
- `continuous.Critic`:
- Add flag `apply_preprocess_net_to_obs_only` to allow the - Add flag `apply_preprocess_net_to_obs_only` to allow the
preprocessing network to be applied to the observations only (without preprocessing network to be applied to the observations only (without
the actions concatenated), which is essential for the case where we want the actions concatenated), which is essential for the case where we want
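
The changelog entry for `Experiment.run()` describes a precedence: an explicit `run_name` wins, the legacy `experiment_name` keyword is still honored, and otherwise the experiment's own `name` is used. A minimal standalone sketch of that precedence follows. The helper `resolve_run_name` is hypothetical and not part of the library, and it uses the standard `warnings` module as a stand-in for the library's own `deprecation` utility.

```python
import warnings
from typing import Optional


def resolve_run_name(default_name: str, run_name: Optional[str] = None, **kwargs: object) -> str:
    """Hypothetical helper (not library API) mirroring the documented name precedence."""
    # The legacy keyword is still accepted, but flagged as deprecated.
    legacy_name = kwargs.pop("experiment_name", None)
    if legacy_name is not None:
        warnings.warn(
            "Use run_name instead of experiment_name.",
            DeprecationWarning,
            stacklevel=2,
        )
        run_name = str(legacy_name)
    # Any other leftover keyword is a caller error.
    if kwargs:
        raise TypeError(f"Received unexpected arguments: {kwargs}")
    # Fall back to the experiment's own name when no run name is given.
    return default_name if run_name is None else run_name
```

For instance, `resolve_run_name("ppo_default", run_name="trial_1")` yields `"trial_1"`, while calling it with only the default name returns that default.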

View File

@ -104,7 +104,7 @@ def main(
) )
experiment = builder.build() experiment = builder.build()
experiment.run(override_experiment_name=log_name) experiment.run(run_name=log_name)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -96,7 +96,7 @@ def main(
.with_epoch_stop_callback(AtariEpochStopCallback(task)) .with_epoch_stop_callback(AtariEpochStopCallback(task))
.build() .build()
) )
experiment.run(override_experiment_name=log_name) experiment.run(run_name=log_name)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -115,7 +115,7 @@ def main(
), ),
) )
experiment = builder.build() experiment = builder.build()
experiment.run(override_experiment_name=log_name) experiment.run(run_name=log_name)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -103,7 +103,7 @@ def main(
), ),
) )
experiment = builder.build() experiment = builder.build()
experiment.run(override_experiment_name=log_name) experiment.run(run_name=log_name)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -83,7 +83,7 @@ def main(
.with_critic_factory_default(hidden_sizes, nn.Tanh) .with_critic_factory_default(hidden_sizes, nn.Tanh)
.build() .build()
) )
experiment.run(override_experiment_name=log_name) experiment.run(run_name=log_name)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -74,7 +74,7 @@ def main(
.with_critic_factory_default(hidden_sizes) .with_critic_factory_default(hidden_sizes)
.build() .build()
) )
experiment.run(override_experiment_name=log_name) experiment.run(run_name=log_name)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -85,7 +85,7 @@ def main(
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh) .with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
.build() .build()
) )
experiment.run(override_experiment_name=log_name) experiment.run(run_name=log_name)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -95,7 +95,7 @@ def main(
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh) .with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
.build() .build()
) )
experiment.run(override_experiment_name=log_name) experiment.run(run_name=log_name)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -14,8 +14,6 @@ These plots are saved in the log directory and displayed in the console.
import os import os
import sys import sys
from collections.abc import Sequence
from typing import Literal
import torch import torch
@ -41,86 +39,30 @@ log = logging.getLogger(__name__)
def main( def main(
experiment_config: ExperimentConfig, num_experiments: int = 2,
task: str = "Ant-v4",
num_experiments: int = 5,
buffer_size: int = 4096,
hidden_sizes: Sequence[int] = (64, 64),
lr: float = 3e-4,
gamma: float = 0.99,
epoch: int = 3,
step_per_epoch: int = 30000,
step_per_collect: int = 2048,
repeat_per_collect: int = 10,
batch_size: int = 64,
training_num: int = 10,
test_num: int = 10,
rew_norm: bool = True,
vf_coef: float = 0.25,
ent_coef: float = 0.0,
gae_lambda: float = 0.95,
bound_action_method: Literal["clip", "tanh"] | None = "clip",
lr_decay: bool = True,
max_grad_norm: float = 0.5,
eps_clip: float = 0.2,
dual_clip: float | None = None,
value_clip: bool = False,
norm_adv: bool = False,
recompute_adv: bool = True,
run_experiments_sequentially: bool = True, run_experiments_sequentially: bool = True,
) -> str: ) -> RLiableExperimentResult:
"""Use the high-level API of TianShou to evaluate the PPO algorithm on a MuJoCo environment with multiple seeds for """:param run_experiments_sequentially: if True, the experiments are run sequentially, otherwise in parallel.
a given configuration. The results for each run are stored in separate sub-folders. After the agents are trained,
the results are evaluated using the rliable API.
:param experiment_config:
:param task: a mujoco task name
:param num_experiments: how many experiments to run with different seeds
:param buffer_size:
:param hidden_sizes:
:param lr:
:param gamma:
:param epoch:
:param step_per_epoch:
:param step_per_collect:
:param repeat_per_collect:
:param batch_size:
:param training_num:
:param test_num:
:param rew_norm:
:param vf_coef:
:param ent_coef:
:param gae_lambda:
:param bound_action_method:
:param lr_decay:
:param max_grad_norm:
:param eps_clip:
:param dual_clip:
:param value_clip:
:param norm_adv:
:param recompute_adv:
:param run_experiments_sequentially: if True, the experiments are run sequentially, otherwise in parallel.
LIMITATIONS: currently, the parallel execution does not seem to work properly on linux. LIMITATIONS: currently, the parallel execution does not seem to work properly on linux.
It might generally be undesired to run multiple experiments in parallel on the same machine, It might generally be undesired to run multiple experiments in parallel on the same machine,
as a single experiment already uses all available CPU cores by default. as a single experiment already uses all available CPU cores by default.
:return: the directory where the results are stored :return: the `RLiableExperimentResult` loaded from the persistence directory
""" """
task = "Ant-v4"
persistence_dir = os.path.abspath(os.path.join("log", task, "ppo", datetime_tag())) persistence_dir = os.path.abspath(os.path.join("log", task, "ppo", datetime_tag()))
experiment_config.persistence_base_dir = persistence_dir experiment_config = ExperimentConfig(persistence_base_dir=persistence_dir, watch=False)
log.info(f"Will save all experiment results to {persistence_dir}.")
experiment_config.watch = False
sampling_config = SamplingConfig( sampling_config = SamplingConfig(
num_epochs=epoch, num_epochs=1,
step_per_epoch=step_per_epoch, step_per_epoch=5000,
batch_size=batch_size, batch_size=64,
num_train_envs=training_num, num_train_envs=10,
num_test_envs=test_num, num_test_envs=10,
num_test_episodes=test_num, num_test_episodes=10,
buffer_size=buffer_size, buffer_size=4096,
step_per_collect=step_per_collect, step_per_collect=2048,
repeat_per_collect=repeat_per_collect, repeat_per_collect=10,
) )
env_factory = MujocoEnvFactory( env_factory = MujocoEnvFactory(
@ -133,52 +75,45 @@ def main(
else VectorEnvType.SUBPROC_SHARED_MEM, else VectorEnvType.SUBPROC_SHARED_MEM,
) )
experiments = ( hidden_sizes = (64, 64)
experiment_collection = (
PPOExperimentBuilder(env_factory, experiment_config, sampling_config) PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
.with_ppo_params( .with_ppo_params(
PPOParams( PPOParams(
discount_factor=gamma, discount_factor=0.99,
gae_lambda=gae_lambda, gae_lambda=0.95,
action_bound_method=bound_action_method, action_bound_method="clip",
reward_normalization=rew_norm, reward_normalization=True,
ent_coef=ent_coef, ent_coef=0.0,
vf_coef=vf_coef, vf_coef=0.25,
max_grad_norm=max_grad_norm, max_grad_norm=0.5,
value_clip=value_clip, value_clip=False,
advantage_normalization=norm_adv, advantage_normalization=False,
eps_clip=eps_clip, eps_clip=0.2,
dual_clip=dual_clip, dual_clip=None,
recompute_advantage=recompute_adv, recompute_advantage=True,
lr=lr, lr=3e-4,
lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config) lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config),
if lr_decay
else None,
dist_fn=DistributionFunctionFactoryIndependentGaussians(), dist_fn=DistributionFunctionFactoryIndependentGaussians(),
), ),
) )
.with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True) .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh) .with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
.with_logger_factory(LoggerFactoryDefault("tensorboard")) .with_logger_factory(LoggerFactoryDefault("tensorboard"))
.build_default_seeded_experiments(num_experiments) .build_seeded_collection(num_experiments)
) )
if run_experiments_sequentially: if run_experiments_sequentially:
launcher = RegisteredExpLauncher.sequential.create_launcher() launcher = RegisteredExpLauncher.sequential.create_launcher()
else: else:
launcher = RegisteredExpLauncher.joblib.create_launcher() launcher = RegisteredExpLauncher.joblib.create_launcher()
launcher.launch(experiments) experiment_collection.run(launcher)
return persistence_dir rliable_result = RLiableExperimentResult.load_from_disk(persistence_dir)
def eval_experiments(log_dir: str) -> RLiableExperimentResult:
"""Evaluate the experiments in the given log directory using the rliable API."""
rliable_result = RLiableExperimentResult.load_from_disk(log_dir)
rliable_result.eval_results(show_plots=True, save_plots=True) rliable_result.eval_results(show_plots=True, save_plots=True)
return rliable_result return rliable_result
if __name__ == "__main__": if __name__ == "__main__":
log_dir = logging.run_cli(main, level=logging.INFO) result = logging.run_cli(main, level=logging.INFO)
assert isinstance(log_dir, str) # for mypy
evaluation_result = eval_experiments(log_dir)
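
The script above hands a launcher to `experiment_collection.run(...)`, and the new `ExperimentCollection.run` accepts either a launcher instance or a registry entry that can create one. The dispatch can be sketched without the library as follows. All class names here (`SequentialLauncher`, `RegisteredLauncher`, `Collection`) are hypothetical stand-ins, experiments are modeled as plain callables, and `run` returns the results only to keep the sketch easy to check (the method in the diff returns `None`).

```python
from enum import Enum
from typing import Callable, List, Union


class SequentialLauncher:
    """Stand-in launcher (hypothetical): runs the jobs one after another."""

    def launch(self, experiments: List[Callable[[], str]]) -> List[str]:
        return [run() for run in experiments]


class RegisteredLauncher(Enum):
    """Stand-in for a registry enum of launchers (hypothetical)."""

    sequential = "sequential"

    def create_launcher(self) -> SequentialLauncher:
        return SequentialLauncher()


class Collection:
    """Stand-in for a collection of experiments, modeled as callables."""

    def __init__(self, experiments: List[Callable[[], str]]):
        self.experiments = experiments

    def run(self, launcher: Union[SequentialLauncher, RegisteredLauncher]) -> List[str]:
        # Accept either a ready launcher or a registry entry that creates one.
        if isinstance(launcher, RegisteredLauncher):
            launcher = launcher.create_launcher()
        return launcher.launch(experiments=self.experiments)
```

Accepting both forms lets a caller write `collection.run(RegisteredLauncher.sequential)` without constructing the launcher first, while still allowing a customized launcher instance to be passed.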

View File

@ -83,7 +83,7 @@ def main(
.with_critic_ensemble_factory_default(hidden_sizes) .with_critic_ensemble_factory_default(hidden_sizes)
.build() .build()
) )
experiment.run(override_experiment_name=log_name) experiment.run(run_name=log_name)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -72,7 +72,7 @@ def main(
.with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True) .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
.build() .build()
) )
experiment.run(override_experiment_name=log_name) experiment.run(run_name=log_name)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -80,7 +80,7 @@ def main(
.with_common_critic_factory_default(hidden_sizes) .with_common_critic_factory_default(hidden_sizes)
.build() .build()
) )
experiment.run(override_experiment_name=log_name) experiment.run(run_name=log_name)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -85,7 +85,7 @@ def main(
.with_common_critic_factory_default(hidden_sizes, torch.nn.Tanh) .with_common_critic_factory_default(hidden_sizes, torch.nn.Tanh)
.build() .build()
) )
experiment.run(override_experiment_name=log_name) experiment.run(run_name=log_name)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -89,7 +89,7 @@ def main(
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh) .with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
.build() .build()
) )
experiment.run(override_experiment_name=log_name) experiment.run(run_name=log_name)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -49,7 +49,7 @@ def test_experiment_builder_continuous_default_params(builder_cls: type[Experime
sampling_config=sampling_config, sampling_config=sampling_config,
) )
experiment = builder.build() experiment = builder.build()
experiment.run(override_experiment_name="test") experiment.run(run_name="test")
print(experiment) print(experiment)
@ -77,32 +77,5 @@ def test_experiment_builder_discrete_default_params(builder_cls: type[Experiment
sampling_config=sampling_config, sampling_config=sampling_config,
) )
experiment = builder.build() experiment = builder.build()
experiment.run(override_experiment_name="test") experiment.run(run_name="test")
print(experiment) print(experiment)
def test_temp_builder_modification() -> None:
env_factory = DiscreteTestEnvFactory()
sampling_config = SamplingConfig(
num_epochs=1,
step_per_epoch=100,
num_train_envs=2,
num_test_envs=2,
)
builder = PPOExperimentBuilder(
experiment_config=ExperimentConfig(persistence_enabled=False),
env_factory=env_factory,
sampling_config=sampling_config,
)
original_seed = builder.experiment_config.seed
original_train_seed = builder.sampling_config.train_seed
with builder.temp_config_mutation():
builder.experiment_config.seed += 12345
builder.sampling_config.train_seed += 456
exp = builder.build()
assert builder.experiment_config.seed == original_seed
assert builder.sampling_config.train_seed == original_train_seed
assert exp.config.seed == original_seed + 12345
assert exp.sampling_config.train_seed == original_train_seed + 456
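
The removed test covered `temp_config_mutation`, which this commit replaces with `ExperimentBuilder.copy` (a deep copy) inside `build_seeded_collection`. The sketch below reproduces that seeding scheme with toy classes so it runs without the library. `ToyConfig`, `ToyBuilder`, and `build_seeded` are hypothetical, the default seed of 42 is arbitrary, and the name suffix is shortened compared to `get_seeding_info_as_str`.

```python
from copy import deepcopy
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyConfig:
    seed: int = 42  # arbitrary default, chosen for illustration only
    train_seed: int = 42
    num_train_envs: int = 10


@dataclass
class ToyBuilder:
    """Hypothetical stand-in for an experiment builder holding mutable config."""

    name: str = "exp"
    config: ToyConfig = field(default_factory=ToyConfig)

    def copy(self) -> "ToyBuilder":
        # A deep copy duplicates the nested config, so edits to the copy
        # do not leak back into the original builder.
        return deepcopy(self)

    def build(self) -> dict:
        return {"name": self.name, "seed": self.config.seed, "train_seed": self.config.train_seed}


def build_seeded(builder: ToyBuilder, num_experiments: int) -> List[dict]:
    experiments = []
    for i in range(num_experiments):
        b = builder.copy()
        b.config.seed += i
        # Stride by the number of training environments, as in the diff; the
        # assumed rationale is that each environment gets its own seed offset.
        b.config.train_seed += i * b.config.num_train_envs
        exp = b.build()
        exp["name"] += f"_exp_seed={exp['seed']}_train_seed={exp['train_seed']}"
        experiments.append(exp)
    return experiments
```

Because each iteration mutates a fresh copy, the original builder keeps its configured seeds and can be reused afterwards, which is the property the removed test asserted for the old context manager.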

View File

@ -1,12 +1,11 @@
import os import os
import pickle import pickle
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Iterator, Sequence from collections.abc import Sequence
from contextlib import contextmanager from copy import deepcopy
from copy import copy
from dataclasses import dataclass from dataclasses import dataclass
from pprint import pformat from pprint import pformat
from typing import Literal, Self from typing import TYPE_CHECKING, Any, Self, Union, cast
import numpy as np import numpy as np
import torch import torch
@ -80,11 +79,15 @@ from tianshou.highlevel.trainer import (
) )
from tianshou.highlevel.world import World from tianshou.highlevel.world import World
from tianshou.policy import BasePolicy from tianshou.policy import BasePolicy
from tianshou.utils import LazyLogger, logging from tianshou.utils import LazyLogger, deprecation, logging
from tianshou.utils.logging import datetime_tag from tianshou.utils.logging import datetime_tag
from tianshou.utils.net.common import ModuleType from tianshou.utils.net.common import ModuleType
from tianshou.utils.string import ToStringMixin from tianshou.utils.string import ToStringMixin
if TYPE_CHECKING:
from tianshou.evaluation.launcher import ExpLauncher, RegisteredExpLauncher
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -145,8 +148,8 @@ class Experiment(ToStringMixin):
env_factory: EnvFactory, env_factory: EnvFactory,
agent_factory: AgentFactory, agent_factory: AgentFactory,
sampling_config: SamplingConfig, sampling_config: SamplingConfig,
name: str,
logger_factory: LoggerFactory | None = None, logger_factory: LoggerFactory | None = None,
name: str | Literal["DATETIME_TAG"] = "DATETIME_TAG",
): ):
if logger_factory is None: if logger_factory is None:
logger_factory = LoggerFactoryDefault() logger_factory = LoggerFactoryDefault()
@ -155,23 +158,8 @@ class Experiment(ToStringMixin):
self.env_factory = env_factory self.env_factory = env_factory
self.agent_factory = agent_factory self.agent_factory = agent_factory
self.logger_factory = logger_factory self.logger_factory = logger_factory
if name == "DATETIME_TAG":
name = datetime_tag()
self.name = name self.name = name
def get_seeding_info_as_str(self) -> str:
"""Useful for creating unique experiment names based on seeds.
A typical example is to do `experiment.name = f"{experiment.name}_{experiment.get_seeding_info_as_str()}"`.
"""
return "_".join(
[
f"exp_seed={self.config.seed}",
f"train_seed={self.sampling_config.train_seed}",
f"test_seed={self.sampling_config.test_seed}",
],
)
@classmethod @classmethod
def from_directory(cls, directory: str, restore_policy: bool = True) -> "Experiment": def from_directory(cls, directory: str, restore_policy: bool = True) -> "Experiment":
"""Restores an experiment from a previously stored pickle. """Restores an experiment from a previously stored pickle.
@ -186,6 +174,20 @@ class Experiment(ToStringMixin):
experiment.config.policy_restore_directory = directory experiment.config.policy_restore_directory = directory
return experiment return experiment
def get_seeding_info_as_str(self) -> str:
"""Returns information on the seeds used in the experiment as a string.
This can be useful for creating unique experiment names based on seeds, e.g.
`experiment.name = f"{experiment.name}_{experiment.get_seeding_info_as_str()}"`.
"""
return "_".join(
[
f"exp_seed={self.config.seed}",
f"train_seed={self.sampling_config.train_seed}",
f"test_seed={self.sampling_config.test_seed}",
],
)
def _set_seed(self) -> None: def _set_seed(self) -> None:
seed = self.config.seed seed = self.config.seed
log.info(f"Setting random seed {seed}") log.info(f"Setting random seed {seed}")
@ -205,33 +207,41 @@ class Experiment(ToStringMixin):
def run( def run(
self, self,
override_experiment_name: str | Literal["DATETIME_TAG"] | None = None, run_name: str | None = None,
logger_run_id: str | None = None, logger_run_id: str | None = None,
raise_error_on_dirname_collision: bool = True, raise_error_on_dirname_collision: bool = True,
**kwargs: dict[str, Any],
) -> ExperimentResult: ) -> ExperimentResult:
"""Run the experiment and return the results. """Run the experiment and return the results.
:param override_experiment_name: if not None, will adjust the current instance's `name` name attribute. :param run_name: Defines a name for this run of the experiment, which determines
The name corresponds to the directory (within the logging the subdirectory (within the persistence base directory) where all results will be saved.
directory) where all results associated with the experiment will be saved. If None, the experiment's name will be used.
The name may contain path separators (i.e. `os.path.sep`, as used by `os.path.join`), in which case The name may contain path separators (i.e. `os.path.sep`, as used by `os.path.join`), in which case
a nested directory structure will be created. a nested directory structure will be created.
If "DATETIME_TAG" is passed, use a name containing the current date and time. This option
is useful for preventing file-name collisions if a single experiment is executed repeatedly.
:param logger_run_id: Run identifier to use for logger initialization/resumption (applies when :param logger_run_id: Run identifier to use for logger initialization/resumption (applies when
using wandb, in particular). using wandb, in particular).
:param raise_error_on_dirname_collision: set to `False` e.g., when continuing a previously executed :param raise_error_on_dirname_collision: set to `False` e.g., when continuing a previously executed
experiment with the same name. experiment with the same name.
:param kwargs: for backward compatibility with old parameter names only
:return: :return:
""" """
if override_experiment_name is not None: # backward compatibility
if override_experiment_name == "DATETIME_TAG": _experiment_name = kwargs.pop("experiment_name", None)
override_experiment_name = datetime_tag() if _experiment_name is not None:
self.name = override_experiment_name run_name = cast(str, _experiment_name)
deprecation(
"Parameter run_name should now be used instead of experiment_name. "
"Support for experiment_name will be removed in the future.",
)
assert len(kwargs) == 0, f"Received unexpected arguments: {kwargs}"
if run_name is None:
run_name = self.name
# initialize persistence directory # initialize persistence directory
use_persistence = self.config.persistence_enabled use_persistence = self.config.persistence_enabled
persistence_dir = os.path.join(self.config.persistence_base_dir, self.name) persistence_dir = os.path.join(self.config.persistence_base_dir, run_name)
if use_persistence: if use_persistence:
os.makedirs(persistence_dir, exist_ok=not raise_error_on_dirname_collision) os.makedirs(persistence_dir, exist_ok=not raise_error_on_dirname_collision)
@ -240,7 +250,7 @@ class Experiment(ToStringMixin):
enabled=use_persistence and self.config.log_file_enabled, enabled=use_persistence and self.config.log_file_enabled,
): ):
# log initial information # log initial information
log.info(f"Running experiment (name='{self.name}'):\n{self.pprints()}") log.info(f"Running experiment (name='{run_name}'):\n{self.pprints()}")
log.info(f"Working directory: {os.getcwd()}") log.info(f"Working directory: {os.getcwd()}")
self._set_seed() self._set_seed()
@ -271,7 +281,7 @@ class Experiment(ToStringMixin):
if use_persistence: if use_persistence:
logger = self.logger_factory.create_logger( logger = self.logger_factory.create_logger(
log_dir=persistence_dir, log_dir=persistence_dir,
experiment_name=self.name, experiment_name=run_name,
run_id=logger_run_id, run_id=logger_run_id,
config_dict=full_config, config_dict=full_config,
) )
@ -346,6 +356,18 @@ class Experiment(ToStringMixin):
) )
class ExperimentCollection:
def __init__(self, experiments: list[Experiment]):
self.experiments = experiments
def run(self, launcher: Union["ExpLauncher", "RegisteredExpLauncher"]) -> None:
from tianshou.evaluation.launcher import RegisteredExpLauncher
if isinstance(launcher, RegisteredExpLauncher):
launcher = launcher.create_launcher()
launcher.launch(experiments=self.experiments)
class ExperimentBuilder: class ExperimentBuilder:
def __init__( def __init__(
self, self,
@ -364,16 +386,10 @@ class ExperimentBuilder:
self._optim_factory: OptimizerFactory | None = None self._optim_factory: OptimizerFactory | None = None
self._policy_wrapper_factory: PolicyWrapperFactory | None = None self._policy_wrapper_factory: PolicyWrapperFactory | None = None
self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks() self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks()
self._experiment_name: str = "" self._name: str = self.__class__.__name__.replace("Builder", "") + "_" + datetime_tag()
@contextmanager def copy(self) -> Self:
def temp_config_mutation(self) -> Iterator[Self]: return deepcopy(self)
"""Returns the builder instance where the configs can be modified without affecting the current instance."""
original_sampling_config = copy(self.sampling_config)
original_experiment_config = copy(self.experiment_config)
yield self
self.sampling_config = original_sampling_config
self.experiment_config = original_experiment_config
@property @property
def experiment_config(self) -> ExperimentConfig: def experiment_config(self) -> ExperimentConfig:
@ -467,18 +483,17 @@ class ExperimentBuilder:
self._trainer_callbacks.epoch_stop_callback = callback self._trainer_callbacks.epoch_stop_callback = callback
return self return self
def with_experiment_name( def with_name(
self, self,
experiment_name: str | Literal["DATETIME_TAG"] = "DATETIME_TAG", name: str,
) -> Self: ) -> Self:
"""Sets the name of the experiment. """Sets the name of the experiment.
:param experiment_name: the name. If "DATETIME_TAG" (default) is given, the current date and time will be used. :param name: the name to use for this experiment, which, when the experiment is run,
will determine the storage sub-folder by default
:return: the builder :return: the builder
""" """
if experiment_name == "DATETIME_TAG": self._name = name
experiment_name = datetime_tag()
self._experiment_name = experiment_name
return self return self
@abstractmethod @abstractmethod
@ -491,12 +506,9 @@ class ExperimentBuilder:
else: else:
return self._optim_factory return self._optim_factory
def build(self, add_seeding_info_to_name: bool = False) -> Experiment: def build(self) -> Experiment:
"""Creates the experiment based on the options specified via this builder. """Creates the experiment based on the options specified via this builder.
:param add_seeding_info_to_name: whether to add a postfix to the experiment name that contains
info about the training seeds. Useful for creating multiple experiments that only differ
by seeds.
:return: the experiment :return: the experiment
""" """
agent_factory = self._create_agent_factory() agent_factory = self._create_agent_factory()
@ -504,34 +516,31 @@ class ExperimentBuilder:
if self._policy_wrapper_factory: if self._policy_wrapper_factory:
agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory) agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory)
experiment: Experiment = Experiment( experiment: Experiment = Experiment(
self._config, config=self._config,
self._env_factory, env_factory=self._env_factory,
agent_factory, agent_factory=agent_factory,
self._sampling_config, sampling_config=self._sampling_config,
self._logger_factory, name=self._name,
name=self._experiment_name, logger_factory=self._logger_factory,
) )
if add_seeding_info_to_name:
if not experiment.name:
experiment.name = experiment.get_seeding_info_as_str()
else:
experiment.name = f"{experiment.name}_{experiment.get_seeding_info_as_str()}"
return experiment return experiment
def build_default_seeded_experiments(self, num_experiments: int) -> list[Experiment]: def build_seeded_collection(self, num_experiments: int) -> ExperimentCollection:
"""Creates a list of experiments with non-overlapping seeds, starting from the configured seed. """Creates a collection of experiments with non-overlapping random seeds, starting from the configured seed.
Each experiment will have a unique name that is created from the original experiment name and the seeds used. Each experiment in the collection will have a unique name that is created from the original experiment name and the seeds used.
""" """
num_train_envs = self.sampling_config.num_train_envs num_train_envs = self.sampling_config.num_train_envs
seeded_experiments = [] seeded_experiments = []
for i in range(num_experiments): for i in range(num_experiments):
with self.temp_config_mutation(): builder = self.copy()
self.experiment_config.seed += i builder.experiment_config.seed += i
self.sampling_config.train_seed += i * num_train_envs builder.sampling_config.train_seed += i * num_train_envs
seeded_experiments.append(self.build(add_seeding_info_to_name=True)) experiment = builder.build()
return seeded_experiments experiment.name += f"_{experiment.get_seeding_info_as_str()}"
seeded_experiments.append(experiment)
return ExperimentCollection(seeded_experiments)
class _BuilderMixinActorFactory(ActorFutureProviderProtocol): class _BuilderMixinActorFactory(ActorFutureProviderProtocol):