Added contextmanager for ExperimentBuilder modifications

Used it to simplify building of seeded experiments
This commit is contained in:
Michael Panchenko 2024-04-03 17:28:38 +02:00
parent 85e910ec5d
commit ed12b16d70
2 changed files with 88 additions and 37 deletions

View File

@ -79,3 +79,30 @@ def test_experiment_builder_discrete_default_params(builder_cls: type[Experiment
experiment = builder.build()
experiment.run("test")
print(experiment)
def test_temp_builder_modification():
    """Check that `temp_config_mutation` reverts the builder's configs on exit.

    Seeds are changed inside the context and an experiment is built there.
    Afterwards the builder must report its original seeds again, while the
    experiment built inside the context must carry the modified seeds.
    """
    seed_offset = 12345
    train_seed_offset = 456

    builder = PPOExperimentBuilder(
        experiment_config=ExperimentConfig(persistence_enabled=False),
        env_factory=DiscreteTestEnvFactory(),
        sampling_config=SamplingConfig(
            num_epochs=1,
            step_per_epoch=100,
            num_train_envs=2,
            num_test_envs=2,
        ),
    )
    seed_before = builder.experiment_config.seed
    train_seed_before = builder.sampling_config.train_seed

    with builder.temp_config_mutation():
        builder.experiment_config.seed += seed_offset
        builder.sampling_config.train_seed += train_seed_offset
        # Build while the temporary seeds are still in effect.
        exp = builder.build()

    # The builder is back to its pre-context state ...
    assert builder.experiment_config.seed == seed_before
    assert builder.sampling_config.train_seed == train_seed_before
    # ... but the experiment built inside the context carries the modified seeds.
    assert exp.config.seed == seed_before + seed_offset
    assert exp.sampling_config.train_seed == train_seed_before + train_seed_offset

View File

@ -1,11 +1,12 @@
import os
import pickle
from abc import abstractmethod
from collections.abc import Sequence
from collections.abc import Iterator, Sequence
from contextlib import contextmanager
from copy import copy
from dataclasses import asdict, dataclass
from dataclasses import dataclass
from pprint import pformat
from typing import Self
from typing import Literal, Self
import numpy as np
import torch
@ -145,6 +146,7 @@ class Experiment(ToStringMixin):
agent_factory: AgentFactory,
sampling_config: SamplingConfig,
logger_factory: LoggerFactory | None = None,
name: str | Literal["DATETIME_TAG"] = "DATETIME_TAG",
):
if logger_factory is None:
logger_factory = LoggerFactoryDefault()
@ -153,6 +155,22 @@ class Experiment(ToStringMixin):
self.env_factory = env_factory
self.agent_factory = agent_factory
self.logger_factory = logger_factory
if name == "DATETIME_TAG":
name = datetime_tag()
self.name = name
def get_seeding_info_as_str(self) -> str:
    """Return the experiment's seeds as one comma-separated string.

    The result has the form ``exp_seed=<a>,train_seed=<b>,test_seed=<c>``, where the
    values are read from `self.config.seed`, `self.sampling_config.train_seed` and
    `self.sampling_config.test_seed`.

    Handy for making experiment names unique per seed, e.g.
    `experiment.name = f"{experiment.name}_{experiment.get_seeding_info_as_str()}"`.

    :return: the comma-separated seed description
    """
    exp_seed_part = f"exp_seed={self.config.seed}"
    train_seed_part = f"train_seed={self.sampling_config.train_seed}"
    test_seed_part = f"test_seed={self.sampling_config.test_seed}"
    return ",".join((exp_seed_part, train_seed_part, test_seed_part))
@classmethod
def from_directory(cls, directory: str, restore_policy: bool = True) -> "Experiment":
@ -338,6 +356,16 @@ class ExperimentBuilder:
self._optim_factory: OptimizerFactory | None = None
self._policy_wrapper_factory: PolicyWrapperFactory | None = None
self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks()
self._experiment_name: str = ""
@contextmanager
def temp_config_mutation(self) -> Iterator[Self]:
"""Returns the builder instance where the configs can be modified without affecting the current instance."""
original_sampling_config = copy(self.sampling_config)
original_experiment_config = copy(self.experiment_config)
yield self
self.sampling_config = original_sampling_config
self.experiment_config = original_experiment_config
@property
def experiment_config(self) -> ExperimentConfig:
@ -431,6 +459,20 @@ class ExperimentBuilder:
self._trainer_callbacks.epoch_stop_callback = callback
return self
def with_experiment_name(
self,
experiment_name: str | Literal["DATETIME_TAG"] = "DATETIME_TAG",
) -> Self:
"""Sets the name of the experiment.
:param experiment_name: the name. If "DATETIME_TAG" (default) is given, the current date and time will be used.
:return: the builder
"""
if experiment_name == "DATETIME_TAG":
experiment_name = datetime_tag()
self._experiment_name = experiment_name
return self
@abstractmethod
def _create_agent_factory(self) -> AgentFactory:
pass
@ -441,9 +483,12 @@ class ExperimentBuilder:
else:
return self._optim_factory
def build(self) -> Experiment:
def build(self, add_seeding_info_to_name: bool = False) -> Experiment:
"""Creates the experiment based on the options specified via this builder.
:param add_seeding_info_to_name: whether to add a postfix to the experiment name that contains
info about the training seeds. Useful for creating multiple experiments that only differ
by seeds.
:return: the experiment
"""
agent_factory = self._create_agent_factory()
@ -456,46 +501,25 @@ class ExperimentBuilder:
agent_factory,
self._sampling_config,
self._logger_factory,
name=self._experiment_name,
)
if add_seeding_info_to_name:
experiment.name = f"{experiment.name}_{experiment.get_seeding_info_as_str()}"
return experiment
def build_default_seeded_experiments(self, num_experiments: int) -> dict[str, Experiment]:
"""Creates a dict of experiments with non-overlapping seeds, starting from the configured seed.
def build_default_seeded_experiments(self, num_experiments: int) -> list[Experiment]:
"""Creates a list of experiments with non-overlapping seeds, starting from the configured seed.
The keys of the dict are the experiment names, which are derived from the seeds used in the experiments.
Each experiment will have a unique name that is created from the original experiment name and the seeds used.
"""
configured_experiment_config = copy(self.experiment_config)
configured_experiment_seed = configured_experiment_config.seed
configured_sampling_config = copy(self.sampling_config)
configured_train_seed = configured_sampling_config.train_seed
num_train_envs = configured_sampling_config.num_train_envs
num_train_envs = self.sampling_config.num_train_envs
seeded_experiments = {}
seeded_experiments = []
for i in range(num_experiments):
experiment_seed = configured_experiment_seed + i
new_experiment_config_dict = asdict(configured_experiment_config)
new_experiment_config_dict["seed"] = experiment_seed
new_train_seed = i * num_train_envs + configured_train_seed
new_sampling_config_dict = asdict(configured_sampling_config)
new_sampling_config_dict["train_seed"] = new_train_seed
self.experiment_config = ExperimentConfig(**new_experiment_config_dict)
self.sampling_config = SamplingConfig(**new_sampling_config_dict)
exp = self.build()
experiment_name = ",".join(
[
f"exp_seed={exp.config.seed}",
f"train_seed={exp.sampling_config.train_seed}",
f"test_seed={exp.sampling_config.test_seed}",
],
)
seeded_experiments[experiment_name] = exp
# restore original config
self.experiment_config = configured_experiment_config
self.sampling_config = configured_sampling_config
with self.temp_config_mutation():
self.experiment_config.seed += i
self.sampling_config.train_seed += i * num_train_envs
seeded_experiments.append(self.build(add_seeding_info_to_name=True))
return seeded_experiments