extend hl experiment builder
This commit is contained in:
parent
734119ec00
commit
5762d2c2e0
@ -2,7 +2,8 @@ import os
|
|||||||
import pickle
|
import pickle
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from dataclasses import dataclass
|
from copy import copy
|
||||||
|
from dataclasses import dataclass, asdict
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
from typing import Self
|
from typing import Self
|
||||||
|
|
||||||
@ -86,6 +87,33 @@ from tianshou.utils.string import ToStringMixin
|
|||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def shortener(input_string: str | None = None,
|
||||||
|
length: int = 1
|
||||||
|
):
|
||||||
|
"""Shorten the input string by keeping only the first `length` characters of each word.
|
||||||
|
|
||||||
|
If the input string is None or empty, return "default".
|
||||||
|
"""
|
||||||
|
if input_string is None or len(input_string) == 0:
|
||||||
|
return "default"
|
||||||
|
output_parts = []
|
||||||
|
|
||||||
|
for part in input_string.split(","):
|
||||||
|
key, value = part.split("=")
|
||||||
|
modified_key = ""
|
||||||
|
|
||||||
|
key_parts = key.split(".")
|
||||||
|
for key_part in key_parts:
|
||||||
|
for word in key_part.split("_"):
|
||||||
|
modified_key += word[:length] + "_"
|
||||||
|
modified_key = modified_key[:-1] + "."
|
||||||
|
modified_key = modified_key[:-1]
|
||||||
|
|
||||||
|
output_parts.append(f"{modified_key}={value}")
|
||||||
|
|
||||||
|
return ",".join(output_parts)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ExperimentConfig:
|
class ExperimentConfig:
|
||||||
"""Generic config for setting up the experiment, not RL or training specific."""
|
"""Generic config for setting up the experiment, not RL or training specific."""
|
||||||
@ -338,6 +366,22 @@ class ExperimentBuilder:
|
|||||||
self._policy_wrapper_factory: PolicyWrapperFactory | None = None
|
self._policy_wrapper_factory: PolicyWrapperFactory | None = None
|
||||||
self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks()
|
self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks()
|
||||||
|
|
||||||
|
@property
def experiment_config(self) -> ExperimentConfig:
    """The experiment configuration currently stored in this builder (``self._config``)."""
    return self._config

@experiment_config.setter
def experiment_config(self, experiment_config: ExperimentConfig) -> None:
    """Replace the experiment configuration stored in this builder.

    :param experiment_config: the new configuration; it is stored as is (no copy is made).
    """
    self._config = experiment_config
|
||||||
|
|
||||||
|
@property
def sampling_config(self) -> SamplingConfig:
    """The sampling configuration currently stored in this builder (``self._sampling_config``)."""
    return self._sampling_config

@sampling_config.setter
def sampling_config(self, sampling_config: SamplingConfig) -> None:
    """Replace the sampling configuration stored in this builder.

    :param sampling_config: the new configuration; it is stored as is (no copy is made).
    """
    self._sampling_config = sampling_config
|
||||||
|
|
||||||
def with_logger_factory(self, logger_factory: LoggerFactory) -> Self:
|
def with_logger_factory(self, logger_factory: LoggerFactory) -> Self:
|
||||||
"""Allows to customize the logger factory to use.
|
"""Allows to customize the logger factory to use.
|
||||||
|
|
||||||
@ -442,6 +486,43 @@ class ExperimentBuilder:
|
|||||||
)
|
)
|
||||||
return experiment
|
return experiment
|
||||||
|
|
||||||
|
def build_default_seeded_experiments(self, num_experiments: int) -> dict[str, Experiment]:
    """Creates a dict of experiments with non-overlapping seeds, starting from the configured seed.

    For the i-th experiment (counting from 0), the experiment seed is the configured
    seed plus ``i`` and the train seed is the configured train seed plus
    ``i * num_train_envs``. All other configuration values are left unchanged.

    The keys of the dict are the experiment names, which are derived from the seeds
    used in the experiments: the string
    ``experiment_seed=<..>,train_seed=<..>,test_seed=<..>`` is passed through
    ``shortener`` with a word length of 4.

    The builder's own experiment and sampling configuration are restored once this
    method finishes, including when building one of the experiments raises.

    :param num_experiments: the number of experiments to build.
    :return: a mapping from experiment name to the built experiment.
    """
    # Imported locally so that this method does not depend on the module-level
    # import list, which only provides `dataclass` and `asdict`.
    from dataclasses import replace

    original_experiment_config = self.experiment_config
    original_sampling_config = self.sampling_config
    base_experiment_seed = original_experiment_config.seed
    base_train_seed = original_sampling_config.train_seed
    num_train_envs = original_sampling_config.num_train_envs

    seeded_experiments: dict[str, Experiment] = {}
    try:
        for i in range(num_experiments):
            # `replace` creates a new config object of the same type with a single
            # field changed, leaving the builder's original configs untouched.
            self.experiment_config = replace(
                original_experiment_config,
                seed=base_experiment_seed + i,
            )
            # Offsetting by a multiple of num_train_envs keeps the train seeds of
            # different experiments apart.
            self.sampling_config = replace(
                original_sampling_config,
                train_seed=base_train_seed + i * num_train_envs,
            )
            exp = self.build()

            full_name = ",".join(
                [
                    f"experiment_seed={exp.config.seed}",
                    f"train_seed={exp.sampling_config.train_seed}",
                    f"test_seed={exp.sampling_config.test_seed}",
                ],
            )
            seeded_experiments[shortener(full_name, 4)] = exp
    finally:
        # Restore the original configuration even if build() raised, so that the
        # builder is never left holding one of the derived, re-seeded configs.
        self.experiment_config = original_experiment_config
        self.sampling_config = original_sampling_config

    return seeded_experiments
|
||||||
|
|
||||||
|
|
||||||
class _BuilderMixinActorFactory(ActorFutureProviderProtocol):
|
class _BuilderMixinActorFactory(ActorFutureProviderProtocol):
|
||||||
def __init__(self, continuous_actor_type: ContinuousActorType):
|
def __init__(self, continuous_actor_type: ContinuousActorType):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user