extend hl experiment builder

This commit is contained in:
Maximilian Huettenrauch 2024-03-12 11:43:52 +01:00
parent 734119ec00
commit 5762d2c2e0

View File

@ -2,7 +2,8 @@ import os
import pickle
from abc import abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from copy import copy
from dataclasses import dataclass, asdict
from pprint import pformat
from typing import Self
@ -86,6 +87,33 @@ from tianshou.utils.string import ToStringMixin
log = logging.getLogger(__name__)
def shortener(input_string: str | None = None,
length: int = 1
):
"""Shorten the input string by keeping only the first `length` characters of each word.
If the input string is None or empty, return "default".
"""
if input_string is None or len(input_string) == 0:
return "default"
output_parts = []
for part in input_string.split(","):
key, value = part.split("=")
modified_key = ""
key_parts = key.split(".")
for key_part in key_parts:
for word in key_part.split("_"):
modified_key += word[:length] + "_"
modified_key = modified_key[:-1] + "."
modified_key = modified_key[:-1]
output_parts.append(f"{modified_key}={value}")
return ",".join(output_parts)
@dataclass
class ExperimentConfig:
"""Generic config for setting up the experiment, not RL or training specific."""
@ -338,6 +366,22 @@ class ExperimentBuilder:
self._policy_wrapper_factory: PolicyWrapperFactory | None = None
self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks()
@property
def experiment_config(self):
return self._config
@experiment_config.setter
def experiment_config(self, experiment_config: ExperimentConfig):
self._config = experiment_config
@property
def sampling_config(self):
return self._sampling_config
@sampling_config.setter
def sampling_config(self, sampling_config: SamplingConfig):
self._sampling_config = sampling_config
def with_logger_factory(self, logger_factory: LoggerFactory) -> Self:
"""Allows to customize the logger factory to use.
@ -442,6 +486,43 @@ class ExperimentBuilder:
)
return experiment
def build_default_seeded_experiments(self, num_experiments: int) -> dict[str, Experiment]:
"""Creates a dict of experiments with non-overlapping seeds, starting from the configured seed.
The keys of the dict are the experiment names, which are derived from the seeds used in the experiments.
"""
configured_experiment_config = copy(self.experiment_config)
configured_experiment_seed = configured_experiment_config.seed
configured_sampling_config = copy(self.sampling_config)
configured_train_seed = configured_sampling_config.train_seed
num_train_envs = configured_sampling_config.num_train_envs
seeded_experiments = {}
for i in range(num_experiments):
experiment_seed = configured_experiment_seed + i
new_experiment_config_dict = asdict(configured_experiment_config)
new_experiment_config_dict["seed"] = experiment_seed
new_train_seed = i * num_train_envs + configured_train_seed
new_sampling_config_dict = asdict(configured_sampling_config)
new_sampling_config_dict["train_seed"] = new_train_seed
self.experiment_config = ExperimentConfig(**new_experiment_config_dict)
self.sampling_config = SamplingConfig(**new_sampling_config_dict)
exp = self.build()
full_name = ",".join([f"experiment_seed={exp.config.seed}",
f"train_seed={exp.sampling_config.train_seed}",
f"test_seed={exp.sampling_config.test_seed}"])
experiment_name = shortener(full_name, 4)
seeded_experiments[experiment_name] = exp
# restore original config
self.experiment_config = configured_experiment_config
self.sampling_config = configured_sampling_config
return seeded_experiments
class _BuilderMixinActorFactory(ActorFutureProviderProtocol):
def __init__(self, continuous_actor_type: ContinuousActorType):