Improvements pertaining to the handling of multi-experiment creation (#1131)
Description of changes: see individual commits; merged without squashing. Co-authored-by: @maxhuettenrauch. Partly addresses #1129.
commit 9fbf28ef6e

CHANGELOG.md (53 changed lines)
@@ -3,27 +3,42 @@
## Release 1.1.0

### Api Extensions
- Batch received two new methods: `to_dict` and `to_list_of_dicts`. #1063
- `Collector`s can now be closed, and their reset is more granular. #1063
- Trainers can control whether collectors should be reset prior to training. #1063
- Convenience constructor for `CollectStats` called `with_autogenerated_stats`. #1063
- `SamplingConfig` supports `batch_size=None`. #1077
- Batch received new methods: `to_numpy_` and `to_torch_`. #1098, #1117
- `to_dict` in Batch supports also non-recursive conversion. #1098
- Batch `__eq__` implemented, semantic equality check of batches is now possible. #1098
- `data`:
  - `Batch`:
    - Add methods `to_dict` and `to_list_of_dicts`. #1063 #1098
    - Add methods `to_numpy_` and `to_torch_`. #1098, #1117
    - Add `__eq__` (semantic equality check). #1098
  - `data.collector`:
    - `Collector`:
      - Add method `close` #1063
      - Method `reset` is now more granular (new flags controlling behavior). #1063
    - `CollectStats`: Add convenience constructor `with_autogenerated_stats`. #1063
- `trainer`:
  - Trainers can now control whether collectors should be reset prior to training. #1063
- `Batch.keys()` deprecated in favor of `Batch.get_keys()` (needed to make iteration consistent with naming) #1105.
- `Experiment` and `ExperimentConfig` now have a `name`, that can however be overridden when `Experiment.run()` is called. #1074
- When building an `Experiment` from an `ExperimentConfig`, the user has the option to add info about seeds to the name. #1074
- New method in `ExperimentConfig` called `build_default_seeded_experiments`. #1074
- `SamplingConfig` has an explicit training seed, `test_seed` is inferred. #1074
- New `evaluation` package for repeating the same experiment with multiple seeds and aggregating the results (important extension!).
  Launchers for parallelization currently in alpha state. #1074
- `highlevel`:
  - `SamplingConfig`:
    - Add support for `batch_size=None`. #1077
    - Add `training_seed` for explicit seeding of training and test environments, the `test_seed` is inferred from `training_seed`. #1074
  - `highlevel.experiment`:
    - `Experiment` now has a `name` attribute, which can be set using `ExperimentBuilder.with_name` and
      which determines the default run name and therefore the persistence subdirectory.
      It can still be overridden in `Experiment.run()`, the new parameter name being `run_name` rather than
      `experiment_name` (although the latter will still be interpreted correctly). #1074 #1131
    - Add class `ExperimentCollection` for the convenient execution of multiple experiment runs #1131
    - `ExperimentBuilder`:
      - Add method `build_seeded_collection` for the sound creation of multiple
        experiments with varying random seeds #1131
      - Add method `copy` to facilitate the creation of multiple experiments from a single builder #1131
- `evaluation`: New package for repeating the same experiment with multiple seeds and aggregating the results. #1074
  - The module `evaluation.launchers` for parallelization is currently in alpha state.
- Loggers can now restore the logged data into python by using the new `restore_logged_data` method. #1074
- `continuous.Critic`:
  - Add flag `apply_preprocess_net_to_obs_only` to allow the
    preprocessing network to be applied to the observations only (without
    the actions concatenated), which is essential for the case where we want
    to reuse the actor's preprocessing network #1128
- `utils.net`:
  - `continuous.Critic`:
    - Add flag `apply_preprocess_net_to_obs_only` to allow the
      preprocessing network to be applied to the observations only (without
      the actions concatenated), which is essential for the case where we want
      to reuse the actor's preprocessing network #1128

### Fixes
- `CriticFactoryReuseActor`: Enable the Critic flag `apply_preprocess_net_to_obs_only` for continuous critics,

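To make a few of the `Batch` extensions listed above concrete, here is a minimal sketch based only on the changelog entries; anything not stated there (e.g. the exact arguments of `to_dict`) is an assumption.

```python
import numpy as np

from tianshou.data import Batch

batch = Batch(obs=np.zeros((2, 3)), info={"step": np.arange(2)})

# New conversion helpers (#1063, #1098, #1117).
as_dict = batch.to_dict()        # per the changelog, optionally non-recursive
rows = batch.to_list_of_dicts()  # one dict per entry
batch.to_torch_()                # in-place conversion to torch tensors
batch.to_numpy_()                # in-place conversion back to numpy arrays

# Semantic equality and the renamed key accessor (#1098, #1105).
assert batch == Batch(obs=np.zeros((2, 3)), info={"step": np.arange(2)})
keys = list(batch.get_keys())    # Batch.keys() is deprecated in favor of get_keys()
```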
@@ -104,7 +104,7 @@ def main(
)

experiment = builder.build()
experiment.run(override_experiment_name=log_name)
experiment.run(run_name=log_name)


if __name__ == "__main__":

@@ -96,7 +96,7 @@ def main(
.with_epoch_stop_callback(AtariEpochStopCallback(task))
.build()
)
experiment.run(override_experiment_name=log_name)
experiment.run(run_name=log_name)


if __name__ == "__main__":

@@ -115,7 +115,7 @@ def main(
),
)
experiment = builder.build()
experiment.run(override_experiment_name=log_name)
experiment.run(run_name=log_name)


if __name__ == "__main__":

@@ -103,7 +103,7 @@ def main(
),
)
experiment = builder.build()
experiment.run(override_experiment_name=log_name)
experiment.run(run_name=log_name)


if __name__ == "__main__":

@@ -83,7 +83,7 @@ def main(
.with_critic_factory_default(hidden_sizes, nn.Tanh)
.build()
)
experiment.run(override_experiment_name=log_name)
experiment.run(run_name=log_name)


if __name__ == "__main__":

@@ -74,7 +74,7 @@ def main(
.with_critic_factory_default(hidden_sizes)
.build()
)
experiment.run(override_experiment_name=log_name)
experiment.run(run_name=log_name)


if __name__ == "__main__":

@@ -85,7 +85,7 @@ def main(
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
.build()
)
experiment.run(override_experiment_name=log_name)
experiment.run(run_name=log_name)


if __name__ == "__main__":

@@ -95,7 +95,7 @@ def main(
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
.build()
)
experiment.run(override_experiment_name=log_name)
experiment.run(run_name=log_name)


if __name__ == "__main__":
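All of the example-script hunks above contain the same one-line change. A minimal sketch of the new call, assuming `experiment` is an `Experiment` produced by one of the builders shown and `log_name` is the run label the script computed earlier:

```python
# `run_name` selects the persistence subdirectory for this particular run.
result = experiment.run(run_name=log_name)

# The old keyword is still accepted for now, but triggers a deprecation warning
# (see the backward-compatibility shim in Experiment.run further below):
# result = experiment.run(experiment_name=log_name)
```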
@@ -14,8 +14,6 @@ These plots are saved in the log directory and displayed in the console.

import os
import sys
from collections.abc import Sequence
from typing import Literal

import torch

@@ -41,86 +39,30 @@ log = logging.getLogger(__name__)


def main(
experiment_config: ExperimentConfig,
task: str = "Ant-v4",
num_experiments: int = 5,
buffer_size: int = 4096,
hidden_sizes: Sequence[int] = (64, 64),
lr: float = 3e-4,
gamma: float = 0.99,
epoch: int = 3,
step_per_epoch: int = 30000,
step_per_collect: int = 2048,
repeat_per_collect: int = 10,
batch_size: int = 64,
training_num: int = 10,
test_num: int = 10,
rew_norm: bool = True,
vf_coef: float = 0.25,
ent_coef: float = 0.0,
gae_lambda: float = 0.95,
bound_action_method: Literal["clip", "tanh"] | None = "clip",
lr_decay: bool = True,
max_grad_norm: float = 0.5,
eps_clip: float = 0.2,
dual_clip: float | None = None,
value_clip: bool = False,
norm_adv: bool = False,
recompute_adv: bool = True,
num_experiments: int = 2,
run_experiments_sequentially: bool = True,
) -> str:
"""Use the high-level API of TianShou to evaluate the PPO algorithm on a MuJoCo environment with multiple seeds for
a given configuration. The results for each run are stored in separate sub-folders. After the agents are trained,
the results are evaluated using the rliable API.

:param experiment_config:
:param task: a mujoco task name
:param num_experiments: how many experiments to run with different seeds
:param buffer_size:
:param hidden_sizes:
:param lr:
:param gamma:
:param epoch:
:param step_per_epoch:
:param step_per_collect:
:param repeat_per_collect:
:param batch_size:
:param training_num:
:param test_num:
:param rew_norm:
:param vf_coef:
:param ent_coef:
:param gae_lambda:
:param bound_action_method:
:param lr_decay:
:param max_grad_norm:
:param eps_clip:
:param dual_clip:
:param value_clip:
:param norm_adv:
:param recompute_adv:
:param run_experiments_sequentially: if True, the experiments are run sequentially, otherwise in parallel.
) -> RLiableExperimentResult:
""":param run_experiments_sequentially: if True, the experiments are run sequentially, otherwise in parallel.
LIMITATIONS: currently, the parallel execution does not seem to work properly on linux.
It might generally be undesired to run multiple experiments in parallel on the same machine,
as a single experiment already uses all available CPU cores by default.
:return: the directory where the results are stored
"""
task = "Ant-v4"
persistence_dir = os.path.abspath(os.path.join("log", task, "ppo", datetime_tag()))

experiment_config.persistence_base_dir = persistence_dir
log.info(f"Will save all experiment results to {persistence_dir}.")
experiment_config.watch = False
experiment_config = ExperimentConfig(persistence_base_dir=persistence_dir, watch=False)

sampling_config = SamplingConfig(
num_epochs=epoch,
step_per_epoch=step_per_epoch,
batch_size=batch_size,
num_train_envs=training_num,
num_test_envs=test_num,
num_test_episodes=test_num,
buffer_size=buffer_size,
step_per_collect=step_per_collect,
repeat_per_collect=repeat_per_collect,
num_epochs=1,
step_per_epoch=5000,
batch_size=64,
num_train_envs=10,
num_test_envs=10,
num_test_episodes=10,
buffer_size=4096,
step_per_collect=2048,
repeat_per_collect=10,
)

env_factory = MujocoEnvFactory(

@@ -133,52 +75,45 @@ def main(
else VectorEnvType.SUBPROC_SHARED_MEM,
)

experiments = (
hidden_sizes = (64, 64)

experiment_collection = (
PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
.with_ppo_params(
PPOParams(
discount_factor=gamma,
gae_lambda=gae_lambda,
action_bound_method=bound_action_method,
reward_normalization=rew_norm,
ent_coef=ent_coef,
vf_coef=vf_coef,
max_grad_norm=max_grad_norm,
value_clip=value_clip,
advantage_normalization=norm_adv,
eps_clip=eps_clip,
dual_clip=dual_clip,
recompute_advantage=recompute_adv,
lr=lr,
lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
if lr_decay
else None,
discount_factor=0.99,
gae_lambda=0.95,
action_bound_method="clip",
reward_normalization=True,
ent_coef=0.0,
vf_coef=0.25,
max_grad_norm=0.5,
value_clip=False,
advantage_normalization=False,
eps_clip=0.2,
dual_clip=None,
recompute_advantage=True,
lr=3e-4,
lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config),
dist_fn=DistributionFunctionFactoryIndependentGaussians(),
),
)
.with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
.with_logger_factory(LoggerFactoryDefault("tensorboard"))
.build_default_seeded_experiments(num_experiments)
.build_seeded_collection(num_experiments)
)

if run_experiments_sequentially:
launcher = RegisteredExpLauncher.sequential.create_launcher()
else:
launcher = RegisteredExpLauncher.joblib.create_launcher()
launcher.launch(experiments)
experiment_collection.run(launcher)

return persistence_dir


def eval_experiments(log_dir: str) -> RLiableExperimentResult:
"""Evaluate the experiments in the given log directory using the rliable API."""
rliable_result = RLiableExperimentResult.load_from_disk(log_dir)
rliable_result = RLiableExperimentResult.load_from_disk(persistence_dir)
rliable_result.eval_results(show_plots=True, save_plots=True)
return rliable_result


if __name__ == "__main__":
log_dir = logging.run_cli(main, level=logging.INFO)
assert isinstance(log_dir, str) # for mypy
evaluation_result = eval_experiments(log_dir)
result = logging.run_cli(main, level=logging.INFO)
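Condensed, the rewritten example amounts to the following flow. This is a sketch rather than a verbatim copy of the script: the environment factory construction is elided, and the import paths other than `RegisteredExpLauncher` (which appears verbatim in the diff further below) are assumptions that may need adjusting.

```python
from tianshou.evaluation.launcher import RegisteredExpLauncher
from tianshou.evaluation.rliable_evaluation_hl import RLiableExperimentResult  # assumed path
from tianshou.highlevel.config import SamplingConfig  # assumed path
from tianshou.highlevel.experiment import ExperimentConfig, PPOExperimentBuilder  # assumed path


def run_seeded_ppo_experiments(env_factory, persistence_dir: str, num_experiments: int = 2):
    """Build identically configured PPO experiments that differ only in their seeds,
    run them one after another, and aggregate the results with rliable."""
    experiment_config = ExperimentConfig(persistence_base_dir=persistence_dir, watch=False)
    sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=5000, num_train_envs=10, num_test_envs=10)

    collection = (
        PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
        .build_seeded_collection(num_experiments)  # new in #1131
    )
    collection.run(RegisteredExpLauncher.sequential.create_launcher())

    rliable_result = RLiableExperimentResult.load_from_disk(persistence_dir)
    rliable_result.eval_results(show_plots=True, save_plots=True)
    return rliable_result
```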
@@ -83,7 +83,7 @@ def main(
.with_critic_ensemble_factory_default(hidden_sizes)
.build()
)
experiment.run(override_experiment_name=log_name)
experiment.run(run_name=log_name)


if __name__ == "__main__":

@@ -72,7 +72,7 @@ def main(
.with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
.build()
)
experiment.run(override_experiment_name=log_name)
experiment.run(run_name=log_name)


if __name__ == "__main__":

@@ -80,7 +80,7 @@ def main(
.with_common_critic_factory_default(hidden_sizes)
.build()
)
experiment.run(override_experiment_name=log_name)
experiment.run(run_name=log_name)


if __name__ == "__main__":

@@ -85,7 +85,7 @@ def main(
.with_common_critic_factory_default(hidden_sizes, torch.nn.Tanh)
.build()
)
experiment.run(override_experiment_name=log_name)
experiment.run(run_name=log_name)


if __name__ == "__main__":

@@ -89,7 +89,7 @@ def main(
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
.build()
)
experiment.run(override_experiment_name=log_name)
experiment.run(run_name=log_name)


if __name__ == "__main__":
@@ -49,7 +49,7 @@ def test_experiment_builder_continuous_default_params(builder_cls: type[Experime
sampling_config=sampling_config,
)
experiment = builder.build()
experiment.run(override_experiment_name="test")
experiment.run(run_name="test")
print(experiment)


@@ -77,32 +77,5 @@ def test_experiment_builder_discrete_default_params(builder_cls: type[Experiment
sampling_config=sampling_config,
)
experiment = builder.build()
experiment.run(override_experiment_name="test")
experiment.run(run_name="test")
print(experiment)


def test_temp_builder_modification() -> None:
env_factory = DiscreteTestEnvFactory()
sampling_config = SamplingConfig(
num_epochs=1,
step_per_epoch=100,
num_train_envs=2,
num_test_envs=2,
)
builder = PPOExperimentBuilder(
experiment_config=ExperimentConfig(persistence_enabled=False),
env_factory=env_factory,
sampling_config=sampling_config,
)
original_seed = builder.experiment_config.seed
original_train_seed = builder.sampling_config.train_seed

with builder.temp_config_mutation():
builder.experiment_config.seed += 12345
builder.sampling_config.train_seed += 456
exp = builder.build()

assert builder.experiment_config.seed == original_seed
assert builder.sampling_config.train_seed == original_train_seed
assert exp.config.seed == original_seed + 12345
assert exp.sampling_config.train_seed == original_train_seed + 456
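The removed test exercised the builder's old `temp_config_mutation` context manager, which this PR replaces with `ExperimentBuilder.copy()` (see the builder changes further below). A hypothetical equivalent check against the new API could look as follows; the factory and config helpers are the same ones used in the removed test and are assumed to remain available:

```python
def test_builder_copy_does_not_mutate_original() -> None:
    builder = PPOExperimentBuilder(
        experiment_config=ExperimentConfig(persistence_enabled=False),
        env_factory=DiscreteTestEnvFactory(),
        sampling_config=SamplingConfig(num_epochs=1, step_per_epoch=100, num_train_envs=2, num_test_envs=2),
    )
    original_seed = builder.experiment_config.seed
    original_train_seed = builder.sampling_config.train_seed

    # Mutate a deep copy instead of temporarily mutating the builder in place.
    builder_copy = builder.copy()
    builder_copy.experiment_config.seed += 12345
    builder_copy.sampling_config.train_seed += 456
    exp = builder_copy.build()

    assert builder.experiment_config.seed == original_seed
    assert builder.sampling_config.train_seed == original_train_seed
    assert exp.config.seed == original_seed + 12345
    assert exp.sampling_config.train_seed == original_train_seed + 456
```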
@@ -1,12 +1,11 @@
import os
import pickle
from abc import abstractmethod
from collections.abc import Iterator, Sequence
from contextlib import contextmanager
from copy import copy
from collections.abc import Sequence
from copy import deepcopy
from dataclasses import dataclass
from pprint import pformat
from typing import Literal, Self
from typing import TYPE_CHECKING, Any, Self, Union, cast

import numpy as np
import torch

@@ -80,11 +79,15 @@ from tianshou.highlevel.trainer import (
)
from tianshou.highlevel.world import World
from tianshou.policy import BasePolicy
from tianshou.utils import LazyLogger, logging
from tianshou.utils import LazyLogger, deprecation, logging
from tianshou.utils.logging import datetime_tag
from tianshou.utils.net.common import ModuleType
from tianshou.utils.string import ToStringMixin

if TYPE_CHECKING:
from tianshou.evaluation.launcher import ExpLauncher, RegisteredExpLauncher


log = logging.getLogger(__name__)


@@ -145,8 +148,8 @@ class Experiment(ToStringMixin):
env_factory: EnvFactory,
agent_factory: AgentFactory,
sampling_config: SamplingConfig,
name: str,
logger_factory: LoggerFactory | None = None,
name: str | Literal["DATETIME_TAG"] = "DATETIME_TAG",
):
if logger_factory is None:
logger_factory = LoggerFactoryDefault()

@@ -155,23 +158,8 @@ class Experiment(ToStringMixin):
self.env_factory = env_factory
self.agent_factory = agent_factory
self.logger_factory = logger_factory
if name == "DATETIME_TAG":
name = datetime_tag()
self.name = name

def get_seeding_info_as_str(self) -> str:
"""Useful for creating unique experiment names based on seeds.

A typical example is to do `experiment.name = f"{experiment.name}_{experiment.get_seeding_info_as_str()}"`.
"""
return "_".join(
[
f"exp_seed={self.config.seed}",
f"train_seed={self.sampling_config.train_seed}",
f"test_seed={self.sampling_config.test_seed}",
],
)

@classmethod
def from_directory(cls, directory: str, restore_policy: bool = True) -> "Experiment":
"""Restores an experiment from a previously stored pickle.

@@ -186,6 +174,20 @@ class Experiment(ToStringMixin):
experiment.config.policy_restore_directory = directory
return experiment

def get_seeding_info_as_str(self) -> str:
"""Returns information on the seeds used in the experiment as a string.

This can be useful for creating unique experiment names based on seeds, e.g.
A typical example is to do `experiment.name = f"{experiment.name}_{experiment.get_seeding_info_as_str()}"`.
"""
return "_".join(
[
f"exp_seed={self.config.seed}",
f"train_seed={self.sampling_config.train_seed}",
f"test_seed={self.sampling_config.test_seed}",
],
)

def _set_seed(self) -> None:
seed = self.config.seed
log.info(f"Setting random seed {seed}")

@@ -205,33 +207,41 @@ class Experiment(ToStringMixin):

def run(
self,
override_experiment_name: str | Literal["DATETIME_TAG"] | None = None,
run_name: str | None = None,
logger_run_id: str | None = None,
raise_error_on_dirname_collision: bool = True,
**kwargs: dict[str, Any],
) -> ExperimentResult:
"""Run the experiment and return the results.

:param override_experiment_name: if not None, will adjust the current instance's `name` name attribute.
The name corresponds to the directory (within the logging
directory) where all results associated with the experiment will be saved.
:param run_name: Defines a name for this run of the experiment, which determines
the subdirectory (within the persistence base directory) where all results will be saved.
If None, the experiment's name will be used.
The name may contain path separators (i.e. `os.path.sep`, as used by `os.path.join`), in which case
a nested directory structure will be created.
If "DATETIME_TAG" is passed, use a name containing the current date and time. This option
is useful for preventing file-name collisions if a single experiment is executed repeatedly.
:param logger_run_id: Run identifier to use for logger initialization/resumption (applies when
using wandb, in particular).
:param raise_error_on_dirname_collision: set to `False` e.g., when continuing a previously executed
experiment with the same name.
:param kwargs: for backward compatibility with old parameter names only
:return:
"""
if override_experiment_name is not None:
if override_experiment_name == "DATETIME_TAG":
override_experiment_name = datetime_tag()
self.name = override_experiment_name
# backward compatibility
_experiment_name = kwargs.pop("experiment_name", None)
if _experiment_name is not None:
run_name = cast(str, _experiment_name)
deprecation(
"Parameter run_name should now be used instead of experiment_name. "
"Support for experiment_name will be removed in the future.",
)
assert len(kwargs) == 0, f"Received unexpected arguments: {kwargs}"

if run_name is None:
run_name = self.name

# initialize persistence directory
use_persistence = self.config.persistence_enabled
persistence_dir = os.path.join(self.config.persistence_base_dir, self.name)
persistence_dir = os.path.join(self.config.persistence_base_dir, run_name)
if use_persistence:
os.makedirs(persistence_dir, exist_ok=not raise_error_on_dirname_collision)
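As documented above, `run_name` doubles as a relative path under the persistence base directory. A small illustrative sketch, assuming `experiment` is a built `Experiment` (the `datetime_tag` helper is the one imported by this module):

```python
import os

from tianshou.utils.logging import datetime_tag

# Group repeated runs of the same experiment in one parent directory:
# results end up in <persistence_base_dir>/my_experiment/<datetime tag>/.
experiment.run(run_name=os.path.join("my_experiment", datetime_tag()))
```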
@@ -240,7 +250,7 @@ class Experiment(ToStringMixin):
enabled=use_persistence and self.config.log_file_enabled,
):
# log initial information
log.info(f"Running experiment (name='{self.name}'):\n{self.pprints()}")
log.info(f"Running experiment (name='{run_name}'):\n{self.pprints()}")
log.info(f"Working directory: {os.getcwd()}")

self._set_seed()

@@ -271,7 +281,7 @@ class Experiment(ToStringMixin):
if use_persistence:
logger = self.logger_factory.create_logger(
log_dir=persistence_dir,
experiment_name=self.name,
experiment_name=run_name,
run_id=logger_run_id,
config_dict=full_config,
)

@@ -346,6 +356,18 @@
)


class ExperimentCollection:
def __init__(self, experiments: list[Experiment]):
self.experiments = experiments

def run(self, launcher: Union["ExpLauncher", "RegisteredExpLauncher"]) -> None:
from tianshou.evaluation.launcher import RegisteredExpLauncher

if isinstance(launcher, RegisteredExpLauncher):
launcher = launcher.create_launcher()
launcher.launch(experiments=self.experiments)


class ExperimentBuilder:
def __init__(
self,
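As the `run` signature above shows, the new `ExperimentCollection` accepts either an already instantiated launcher or a `RegisteredExpLauncher` enum member, which it instantiates itself. A brief usage sketch (`collection` stands for the result of `ExperimentBuilder.build_seeded_collection`, shown further below):

```python
from tianshou.evaluation.launcher import RegisteredExpLauncher

# Pass the registered launcher directly; ExperimentCollection.run will call
# create_launcher() on it before launching the experiments.
collection.run(RegisteredExpLauncher.sequential)

# Or create the launcher explicitly, e.g. the joblib-based parallel launcher
# (currently in alpha state according to the changelog):
# collection.run(RegisteredExpLauncher.joblib.create_launcher())
```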
@@ -364,16 +386,10 @@ class ExperimentBuilder:
self._optim_factory: OptimizerFactory | None = None
self._policy_wrapper_factory: PolicyWrapperFactory | None = None
self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks()
self._experiment_name: str = ""
self._name: str = self.__class__.__name__.replace("Builder", "") + "_" + datetime_tag()

@contextmanager
def temp_config_mutation(self) -> Iterator[Self]:
"""Returns the builder instance where the configs can be modified without affecting the current instance."""
original_sampling_config = copy(self.sampling_config)
original_experiment_config = copy(self.experiment_config)
yield self
self.sampling_config = original_sampling_config
self.experiment_config = original_experiment_config
def copy(self) -> Self:
return deepcopy(self)

@property
def experiment_config(self) -> ExperimentConfig:

@@ -467,18 +483,17 @@ class ExperimentBuilder:
self._trainer_callbacks.epoch_stop_callback = callback
return self

def with_experiment_name(
def with_name(
self,
experiment_name: str | Literal["DATETIME_TAG"] = "DATETIME_TAG",
name: str,
) -> Self:
"""Sets the name of the experiment.

:param experiment_name: the name. If "DATETIME_TAG" (default) is given, the current date and time will be used.
:param name: the name to use for this experiment, which, when the experiment is run,
will determine the storage sub-folder by default
:return: the builder
"""
if experiment_name == "DATETIME_TAG":
experiment_name = datetime_tag()
self._experiment_name = experiment_name
self._name = name
return self

@abstractmethod

@@ -491,12 +506,9 @@ class ExperimentBuilder:
else:
return self._optim_factory

def build(self, add_seeding_info_to_name: bool = False) -> Experiment:
def build(self) -> Experiment:
"""Creates the experiment based on the options specified via this builder.

:param add_seeding_info_to_name: whether to add a postfix to the experiment name that contains
info about the training seeds. Useful for creating multiple experiments that only differ
by seeds.
:return: the experiment
"""
agent_factory = self._create_agent_factory()

@@ -504,34 +516,31 @@ class ExperimentBuilder:
if self._policy_wrapper_factory:
agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory)
experiment: Experiment = Experiment(
self._config,
self._env_factory,
agent_factory,
self._sampling_config,
self._logger_factory,
name=self._experiment_name,
config=self._config,
env_factory=self._env_factory,
agent_factory=agent_factory,
sampling_config=self._sampling_config,
name=self._name,
logger_factory=self._logger_factory,
)
if add_seeding_info_to_name:
if not experiment.name:
experiment.name = experiment.get_seeding_info_as_str()
else:
experiment.name = f"{experiment.name}_{experiment.get_seeding_info_as_str()}"
return experiment

def build_default_seeded_experiments(self, num_experiments: int) -> list[Experiment]:
"""Creates a list of experiments with non-overlapping seeds, starting from the configured seed.
def build_seeded_collection(self, num_experiments: int) -> ExperimentCollection:
"""Creates a collection of experiments with non-overlapping random seeds, starting from the configured seed.

Each experiment will have a unique name that is created from the original experiment name and the seeds used.
Each experiment in the collection will have a unique name that is created from the original experiment name and the seeds used.
"""
num_train_envs = self.sampling_config.num_train_envs

seeded_experiments = []
for i in range(num_experiments):
with self.temp_config_mutation():
self.experiment_config.seed += i
self.sampling_config.train_seed += i * num_train_envs
seeded_experiments.append(self.build(add_seeding_info_to_name=True))
return seeded_experiments
builder = self.copy()
builder.experiment_config.seed += i
builder.sampling_config.train_seed += i * num_train_envs
experiment = builder.build()
experiment.name += f"_{experiment.get_seeding_info_as_str()}"
seeded_experiments.append(experiment)
return ExperimentCollection(seeded_experiments)


class _BuilderMixinActorFactory(ActorFutureProviderProtocol):
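To make the seeding scheme of `build_seeded_collection` above concrete: experiment `i` gets the base seed offset by `i` and the training seed offset by `i * num_train_envs`, and its name is suffixed with the seeding info. A small hedged illustration (exact names and values depend on the builder's configuration):

```python
collection = builder.build_seeded_collection(3)
for experiment in collection.experiments:
    # e.g. "..._exp_seed=2_train_seed=20_test_seed=..." for the third experiment
    # of a builder with base seed 0 and 10 training environments.
    print(experiment.name, experiment.config.seed, experiment.sampling_config.train_seed)
```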