Added contextmanager for ExperimentBuilder modifications

Used it to simplify building of seeded experiments
This commit is contained in:
Michael Panchenko 2024-04-03 17:28:38 +02:00
parent 85e910ec5d
commit ed12b16d70
2 changed files with 88 additions and 37 deletions

View File

@ -79,3 +79,30 @@ def test_experiment_builder_discrete_default_params(builder_cls: type[Experiment
experiment = builder.build() experiment = builder.build()
experiment.run("test") experiment.run("test")
print(experiment) print(experiment)
def test_temp_builder_modification():
env_factory = DiscreteTestEnvFactory()
sampling_config = SamplingConfig(
num_epochs=1,
step_per_epoch=100,
num_train_envs=2,
num_test_envs=2,
)
builder = PPOExperimentBuilder(
experiment_config=ExperimentConfig(persistence_enabled=False),
env_factory=env_factory,
sampling_config=sampling_config,
)
original_seed = builder.experiment_config.seed
original_train_seed = builder.sampling_config.train_seed
with builder.temp_config_mutation():
builder.experiment_config.seed += 12345
builder.sampling_config.train_seed += 456
exp = builder.build()
assert builder.experiment_config.seed == original_seed
assert builder.sampling_config.train_seed == original_train_seed
assert exp.config.seed == original_seed + 12345
assert exp.sampling_config.train_seed == original_train_seed + 456

View File

@ -1,11 +1,12 @@
import os import os
import pickle import pickle
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Sequence from collections.abc import Iterator, Sequence
from contextlib import contextmanager
from copy import copy from copy import copy
from dataclasses import asdict, dataclass from dataclasses import dataclass
from pprint import pformat from pprint import pformat
from typing import Self from typing import Literal, Self
import numpy as np import numpy as np
import torch import torch
@ -145,6 +146,7 @@ class Experiment(ToStringMixin):
agent_factory: AgentFactory, agent_factory: AgentFactory,
sampling_config: SamplingConfig, sampling_config: SamplingConfig,
logger_factory: LoggerFactory | None = None, logger_factory: LoggerFactory | None = None,
name: str | Literal["DATETIME_TAG"] = "DATETIME_TAG",
): ):
if logger_factory is None: if logger_factory is None:
logger_factory = LoggerFactoryDefault() logger_factory = LoggerFactoryDefault()
@ -153,6 +155,22 @@ class Experiment(ToStringMixin):
self.env_factory = env_factory self.env_factory = env_factory
self.agent_factory = agent_factory self.agent_factory = agent_factory
self.logger_factory = logger_factory self.logger_factory = logger_factory
if name == "DATETIME_TAG":
name = datetime_tag()
self.name = name
def get_seeding_info_as_str(self) -> str:
"""Useful for creating unique experiment names based on seeds.
A typical example is to do `experiment.name = f"{experiment.name}_{experiment.get_seeding_info_as_str()}"`.
"""
return ",".join(
[
f"exp_seed={self.config.seed}",
f"train_seed={self.sampling_config.train_seed}",
f"test_seed={self.sampling_config.test_seed}",
],
)
@classmethod @classmethod
def from_directory(cls, directory: str, restore_policy: bool = True) -> "Experiment": def from_directory(cls, directory: str, restore_policy: bool = True) -> "Experiment":
@ -338,6 +356,16 @@ class ExperimentBuilder:
self._optim_factory: OptimizerFactory | None = None self._optim_factory: OptimizerFactory | None = None
self._policy_wrapper_factory: PolicyWrapperFactory | None = None self._policy_wrapper_factory: PolicyWrapperFactory | None = None
self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks() self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks()
self._experiment_name: str = ""
@contextmanager
def temp_config_mutation(self) -> Iterator[Self]:
"""Returns the builder instance where the configs can be modified without affecting the current instance."""
original_sampling_config = copy(self.sampling_config)
original_experiment_config = copy(self.experiment_config)
yield self
self.sampling_config = original_sampling_config
self.experiment_config = original_experiment_config
@property @property
def experiment_config(self) -> ExperimentConfig: def experiment_config(self) -> ExperimentConfig:
@ -431,6 +459,20 @@ class ExperimentBuilder:
self._trainer_callbacks.epoch_stop_callback = callback self._trainer_callbacks.epoch_stop_callback = callback
return self return self
def with_experiment_name(
self,
experiment_name: str | Literal["DATETIME_TAG"] = "DATETIME_TAG",
) -> Self:
"""Sets the name of the experiment.
:param experiment_name: the name. If "DATETIME_TAG" (default) is given, the current date and time will be used.
:return: the builder
"""
if experiment_name == "DATETIME_TAG":
experiment_name = datetime_tag()
self._experiment_name = experiment_name
return self
@abstractmethod @abstractmethod
def _create_agent_factory(self) -> AgentFactory: def _create_agent_factory(self) -> AgentFactory:
pass pass
@ -441,9 +483,12 @@ class ExperimentBuilder:
else: else:
return self._optim_factory return self._optim_factory
def build(self) -> Experiment: def build(self, add_seeding_info_to_name: bool = False) -> Experiment:
"""Creates the experiment based on the options specified via this builder. """Creates the experiment based on the options specified via this builder.
:param add_seeding_info_to_name: whether to add a postfix to the experiment name that contains
info about the training seeds. Useful for creating multiple experiments that only differ
by seeds.
:return: the experiment :return: the experiment
""" """
agent_factory = self._create_agent_factory() agent_factory = self._create_agent_factory()
@ -456,46 +501,25 @@ class ExperimentBuilder:
agent_factory, agent_factory,
self._sampling_config, self._sampling_config,
self._logger_factory, self._logger_factory,
name=self._experiment_name,
) )
if add_seeding_info_to_name:
experiment.name = f"{experiment.name}_{experiment.get_seeding_info_as_str()}"
return experiment return experiment
def build_default_seeded_experiments(self, num_experiments: int) -> dict[str, Experiment]: def build_default_seeded_experiments(self, num_experiments: int) -> list[Experiment]:
"""Creates a dict of experiments with non-overlapping seeds, starting from the configured seed. """Creates a list of experiments with non-overlapping seeds, starting from the configured seed.
The keys of the dict are the experiment names, which are derived from the seeds used in the experiments. Each experiment will have a unique name that is created from the original experiment name and the seeds used.
""" """
configured_experiment_config = copy(self.experiment_config) num_train_envs = self.sampling_config.num_train_envs
configured_experiment_seed = configured_experiment_config.seed
configured_sampling_config = copy(self.sampling_config)
configured_train_seed = configured_sampling_config.train_seed
num_train_envs = configured_sampling_config.num_train_envs
seeded_experiments = {} seeded_experiments = []
for i in range(num_experiments): for i in range(num_experiments):
experiment_seed = configured_experiment_seed + i with self.temp_config_mutation():
new_experiment_config_dict = asdict(configured_experiment_config) self.experiment_config.seed += i
new_experiment_config_dict["seed"] = experiment_seed self.sampling_config.train_seed += i * num_train_envs
seeded_experiments.append(self.build(add_seeding_info_to_name=True))
new_train_seed = i * num_train_envs + configured_train_seed
new_sampling_config_dict = asdict(configured_sampling_config)
new_sampling_config_dict["train_seed"] = new_train_seed
self.experiment_config = ExperimentConfig(**new_experiment_config_dict)
self.sampling_config = SamplingConfig(**new_sampling_config_dict)
exp = self.build()
experiment_name = ",".join(
[
f"exp_seed={exp.config.seed}",
f"train_seed={exp.sampling_config.train_seed}",
f"test_seed={exp.sampling_config.test_seed}",
],
)
seeded_experiments[experiment_name] = exp
# restore original config
self.experiment_config = configured_experiment_config
self.sampling_config = configured_sampling_config
return seeded_experiments return seeded_experiments