Added contextmanager for ExperimentBuilder modifications
Used it to simplify building of seeded experiments
This commit is contained in:
parent
85e910ec5d
commit
ed12b16d70
@ -79,3 +79,30 @@ def test_experiment_builder_discrete_default_params(builder_cls: type[Experiment
|
||||
experiment = builder.build()
|
||||
experiment.run("test")
|
||||
print(experiment)
|
||||
|
||||
|
||||
def test_temp_builder_modification():
    """Check that ``temp_config_mutation`` resets the builder's configs on exit.

    Seeds changed inside the ``with`` block must be visible on the experiment
    built inside the block, while the builder itself must again report the
    seeds it had before the block was entered.
    """
    factory = DiscreteTestEnvFactory()
    sampling = SamplingConfig(
        num_epochs=1,
        step_per_epoch=100,
        num_train_envs=2,
        num_test_envs=2,
    )
    builder = PPOExperimentBuilder(
        experiment_config=ExperimentConfig(persistence_enabled=False),
        env_factory=factory,
        sampling_config=sampling,
    )

    # Remember the seeds the builder starts out with.
    seed_before = builder.experiment_config.seed
    train_seed_before = builder.sampling_config.train_seed

    with builder.temp_config_mutation():
        builder.experiment_config.seed += 12345
        builder.sampling_config.train_seed += 456
        built_experiment = builder.build()

    # The builder is back to its initial seeds ...
    assert builder.experiment_config.seed == seed_before
    assert builder.sampling_config.train_seed == train_seed_before
    # ... while the experiment built inside the block kept the shifted ones.
    assert built_experiment.config.seed == seed_before + 12345
    assert built_experiment.sampling_config.train_seed == train_seed_before + 456
|
||||
|
@ -1,11 +1,12 @@
|
||||
import os
|
||||
import pickle
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Iterator, Sequence
|
||||
from contextlib import contextmanager
|
||||
from copy import copy
|
||||
from dataclasses import asdict, dataclass
|
||||
from dataclasses import dataclass
|
||||
from pprint import pformat
|
||||
from typing import Self
|
||||
from typing import Literal, Self
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -145,6 +146,7 @@ class Experiment(ToStringMixin):
|
||||
agent_factory: AgentFactory,
|
||||
sampling_config: SamplingConfig,
|
||||
logger_factory: LoggerFactory | None = None,
|
||||
name: str | Literal["DATETIME_TAG"] = "DATETIME_TAG",
|
||||
):
|
||||
if logger_factory is None:
|
||||
logger_factory = LoggerFactoryDefault()
|
||||
@ -153,6 +155,22 @@ class Experiment(ToStringMixin):
|
||||
self.env_factory = env_factory
|
||||
self.agent_factory = agent_factory
|
||||
self.logger_factory = logger_factory
|
||||
if name == "DATETIME_TAG":
|
||||
name = datetime_tag()
|
||||
self.name = name
|
||||
|
||||
def get_seeding_info_as_str(self) -> str:
    """Describe the seeds used by this experiment as a single string.

    The result has the form ``exp_seed=<..>,train_seed=<..>,test_seed=<..>`` and is
    handy for making experiment names unique per seed combination, e.g.
    `experiment.name = f"{experiment.name}_{experiment.get_seeding_info_as_str()}"`.

    :return: the comma-separated seed description
    """
    labelled_seeds = (
        ("exp_seed", self.config.seed),
        ("train_seed", self.sampling_config.train_seed),
        ("test_seed", self.sampling_config.test_seed),
    )
    return ",".join(f"{label}={seed}" for label, seed in labelled_seeds)
|
||||
|
||||
@classmethod
|
||||
def from_directory(cls, directory: str, restore_policy: bool = True) -> "Experiment":
|
||||
@ -338,6 +356,16 @@ class ExperimentBuilder:
|
||||
self._optim_factory: OptimizerFactory | None = None
|
||||
self._policy_wrapper_factory: PolicyWrapperFactory | None = None
|
||||
self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks()
|
||||
self._experiment_name: str = ""
|
||||
|
||||
@contextmanager
|
||||
def temp_config_mutation(self) -> Iterator[Self]:
|
||||
"""Returns the builder instance where the configs can be modified without affecting the current instance."""
|
||||
original_sampling_config = copy(self.sampling_config)
|
||||
original_experiment_config = copy(self.experiment_config)
|
||||
yield self
|
||||
self.sampling_config = original_sampling_config
|
||||
self.experiment_config = original_experiment_config
|
||||
|
||||
@property
|
||||
def experiment_config(self) -> ExperimentConfig:
|
||||
@ -431,6 +459,20 @@ class ExperimentBuilder:
|
||||
self._trainer_callbacks.epoch_stop_callback = callback
|
||||
return self
|
||||
|
||||
def with_experiment_name(
|
||||
self,
|
||||
experiment_name: str | Literal["DATETIME_TAG"] = "DATETIME_TAG",
|
||||
) -> Self:
|
||||
"""Sets the name of the experiment.
|
||||
|
||||
:param experiment_name: the name. If "DATETIME_TAG" (default) is given, the current date and time will be used.
|
||||
:return: the builder
|
||||
"""
|
||||
if experiment_name == "DATETIME_TAG":
|
||||
experiment_name = datetime_tag()
|
||||
self._experiment_name = experiment_name
|
||||
return self
|
||||
|
||||
@abstractmethod
|
||||
def _create_agent_factory(self) -> AgentFactory:
|
||||
pass
|
||||
@ -441,9 +483,12 @@ class ExperimentBuilder:
|
||||
else:
|
||||
return self._optim_factory
|
||||
|
||||
def build(self) -> Experiment:
|
||||
def build(self, add_seeding_info_to_name: bool = False) -> Experiment:
|
||||
"""Creates the experiment based on the options specified via this builder.
|
||||
|
||||
:param add_seeding_info_to_name: whether to add a postfix to the experiment name that contains
|
||||
info about the training seeds. Useful for creating multiple experiments that only differ
|
||||
by seeds.
|
||||
:return: the experiment
|
||||
"""
|
||||
agent_factory = self._create_agent_factory()
|
||||
@ -456,46 +501,25 @@ class ExperimentBuilder:
|
||||
agent_factory,
|
||||
self._sampling_config,
|
||||
self._logger_factory,
|
||||
name=self._experiment_name,
|
||||
)
|
||||
if add_seeding_info_to_name:
|
||||
experiment.name = f"{experiment.name}_{experiment.get_seeding_info_as_str()}"
|
||||
return experiment
|
||||
|
||||
def build_default_seeded_experiments(self, num_experiments: int) -> list[Experiment]:
    """Creates a list of experiments with non-overlapping seeds, starting from the configured seed.

    Each experiment will have a unique name that is created from the original experiment name and the seeds used.

    The i-th experiment (counting from 0) uses the configured experiment seed plus ``i`` and the
    configured train seed plus ``i * num_train_envs``. The builder's own configs are left unchanged.

    :param num_experiments: the number of experiments to build
    :return: the experiments, ordered by increasing seed offset
    """
    # NOTE(review): the train seed is advanced in steps of num_train_envs, presumably because
    # each train env consumes one consecutive seed - confirm against the env seeding logic.
    num_train_envs = self.sampling_config.num_train_envs

    seeded_experiments = []
    for i in range(num_experiments):
        # The seed shifts only apply inside the context; afterwards the builder is back
        # at the configured seeds, so the offsets of successive iterations do not accumulate.
        with self.temp_config_mutation():
            self.experiment_config.seed += i
            self.sampling_config.train_seed += i * num_train_envs
            seeded_experiments.append(self.build(add_seeding_info_to_name=True))
    return seeded_experiments
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user