extend hl experiment builder
This commit is contained in:
parent
734119ec00
commit
5762d2c2e0
@ -2,7 +2,8 @@ import os
|
||||
import pickle
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from copy import copy
|
||||
from dataclasses import dataclass, asdict
|
||||
from pprint import pformat
|
||||
from typing import Self
|
||||
|
||||
@ -86,6 +87,33 @@ from tianshou.utils.string import ToStringMixin
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def shortener(input_string: str | None = None,
|
||||
length: int = 1
|
||||
):
|
||||
"""Shorten the input string by keeping only the first `length` characters of each word.
|
||||
|
||||
If the input string is None or empty, return "default".
|
||||
"""
|
||||
if input_string is None or len(input_string) == 0:
|
||||
return "default"
|
||||
output_parts = []
|
||||
|
||||
for part in input_string.split(","):
|
||||
key, value = part.split("=")
|
||||
modified_key = ""
|
||||
|
||||
key_parts = key.split(".")
|
||||
for key_part in key_parts:
|
||||
for word in key_part.split("_"):
|
||||
modified_key += word[:length] + "_"
|
||||
modified_key = modified_key[:-1] + "."
|
||||
modified_key = modified_key[:-1]
|
||||
|
||||
output_parts.append(f"{modified_key}={value}")
|
||||
|
||||
return ",".join(output_parts)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExperimentConfig:
|
||||
"""Generic config for setting up the experiment, not RL or training specific."""
|
||||
@ -338,6 +366,22 @@ class ExperimentBuilder:
|
||||
self._policy_wrapper_factory: PolicyWrapperFactory | None = None
|
||||
self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks()
|
||||
|
||||
@property
def experiment_config(self) -> ExperimentConfig:
    """The experiment configuration currently held by this builder (stored in ``self._config``)."""
    return self._config

@experiment_config.setter
def experiment_config(self, experiment_config: ExperimentConfig) -> None:
    """Replace the experiment configuration held by this builder.

    :param experiment_config: the configuration to store in ``self._config``.
    """
    self._config = experiment_config
|
||||
|
||||
@property
def sampling_config(self) -> SamplingConfig:
    """The sampling configuration currently held by this builder (stored in ``self._sampling_config``)."""
    return self._sampling_config

@sampling_config.setter
def sampling_config(self, sampling_config: SamplingConfig) -> None:
    """Replace the sampling configuration held by this builder.

    :param sampling_config: the configuration to store in ``self._sampling_config``.
    """
    self._sampling_config = sampling_config
|
||||
|
||||
def with_logger_factory(self, logger_factory: LoggerFactory) -> Self:
|
||||
"""Allows to customize the logger factory to use.
|
||||
|
||||
@ -442,6 +486,43 @@ class ExperimentBuilder:
|
||||
)
|
||||
return experiment
|
||||
|
||||
def build_default_seeded_experiments(self, num_experiments: int) -> dict[str, Experiment]:
    """Creates a dict of experiments with non-overlapping seeds, starting from the configured seed.

    For the i-th experiment (counting from 0), the experiment seed is the configured
    seed plus ``i`` and the train seed is the configured train seed plus
    ``i * num_train_envs``. The test seed is left as configured.

    The keys of the dict are the experiment names, which are derived from the seeds
    used in the experiments (abbreviated via :func:`shortener`).

    The builder's experiment and sampling configs are temporarily replaced while the
    experiments are built and are restored afterwards, also if building fails.

    :param num_experiments: the number of experiments to build.
    :return: a dict mapping experiment names to the built experiments.
    """
    from dataclasses import replace

    original_experiment_config = self.experiment_config
    original_sampling_config = self.sampling_config
    base_experiment_seed = original_experiment_config.seed
    base_train_seed = original_sampling_config.train_seed
    num_train_envs = original_sampling_config.num_train_envs

    seeded_experiments: dict[str, Experiment] = {}
    try:
        for i in range(num_experiments):
            # `replace` creates modified copies, so the originally configured objects
            # are never mutated and can be restored as they are.
            self.experiment_config = replace(
                original_experiment_config,
                seed=base_experiment_seed + i,
            )
            # The train seed advances in steps of num_train_envs; presumably each
            # train env derives its own seed from train_seed plus its index, so this
            # keeps the ranges disjoint (TODO confirm against the env seeding code).
            self.sampling_config = replace(
                original_sampling_config,
                train_seed=base_train_seed + i * num_train_envs,
            )
            exp = self.build()

            full_name = ",".join(
                [
                    f"experiment_seed={exp.config.seed}",
                    f"train_seed={exp.sampling_config.train_seed}",
                    f"test_seed={exp.sampling_config.test_seed}",
                ],
            )
            seeded_experiments[shortener(full_name, 4)] = exp
    finally:
        # Restore the original configs even if build() raised, so that the builder
        # is not left in a modified state.
        self.experiment_config = original_experiment_config
        self.sampling_config = original_sampling_config

    return seeded_experiments
|
||||
|
||||
|
||||
class _BuilderMixinActorFactory(ActorFutureProviderProtocol):
|
||||
def __init__(self, continuous_actor_type: ContinuousActorType):
|
||||
|
Loading…
x
Reference in New Issue
Block a user