Added contextmanager for ExperimentBuilder modifications
Used it to simplify building of seeded experiments
This commit is contained in:
parent
85e910ec5d
commit
ed12b16d70
@ -79,3 +79,30 @@ def test_experiment_builder_discrete_default_params(builder_cls: type[Experiment
|
|||||||
experiment = builder.build()
|
experiment = builder.build()
|
||||||
experiment.run("test")
|
experiment.run("test")
|
||||||
print(experiment)
|
print(experiment)
|
||||||
|
|
||||||
|
|
||||||
|
def test_temp_builder_modification():
|
||||||
|
env_factory = DiscreteTestEnvFactory()
|
||||||
|
sampling_config = SamplingConfig(
|
||||||
|
num_epochs=1,
|
||||||
|
step_per_epoch=100,
|
||||||
|
num_train_envs=2,
|
||||||
|
num_test_envs=2,
|
||||||
|
)
|
||||||
|
builder = PPOExperimentBuilder(
|
||||||
|
experiment_config=ExperimentConfig(persistence_enabled=False),
|
||||||
|
env_factory=env_factory,
|
||||||
|
sampling_config=sampling_config,
|
||||||
|
)
|
||||||
|
original_seed = builder.experiment_config.seed
|
||||||
|
original_train_seed = builder.sampling_config.train_seed
|
||||||
|
|
||||||
|
with builder.temp_config_mutation():
|
||||||
|
builder.experiment_config.seed += 12345
|
||||||
|
builder.sampling_config.train_seed += 456
|
||||||
|
exp = builder.build()
|
||||||
|
|
||||||
|
assert builder.experiment_config.seed == original_seed
|
||||||
|
assert builder.sampling_config.train_seed == original_train_seed
|
||||||
|
assert exp.config.seed == original_seed + 12345
|
||||||
|
assert exp.sampling_config.train_seed == original_train_seed + 456
|
||||||
|
@ -1,11 +1,12 @@
|
|||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from collections.abc import Sequence
|
from collections.abc import Iterator, Sequence
|
||||||
|
from contextlib import contextmanager
|
||||||
from copy import copy
|
from copy import copy
|
||||||
from dataclasses import asdict, dataclass
|
from dataclasses import dataclass
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
from typing import Self
|
from typing import Literal, Self
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -145,6 +146,7 @@ class Experiment(ToStringMixin):
|
|||||||
agent_factory: AgentFactory,
|
agent_factory: AgentFactory,
|
||||||
sampling_config: SamplingConfig,
|
sampling_config: SamplingConfig,
|
||||||
logger_factory: LoggerFactory | None = None,
|
logger_factory: LoggerFactory | None = None,
|
||||||
|
name: str | Literal["DATETIME_TAG"] = "DATETIME_TAG",
|
||||||
):
|
):
|
||||||
if logger_factory is None:
|
if logger_factory is None:
|
||||||
logger_factory = LoggerFactoryDefault()
|
logger_factory = LoggerFactoryDefault()
|
||||||
@ -153,6 +155,22 @@ class Experiment(ToStringMixin):
|
|||||||
self.env_factory = env_factory
|
self.env_factory = env_factory
|
||||||
self.agent_factory = agent_factory
|
self.agent_factory = agent_factory
|
||||||
self.logger_factory = logger_factory
|
self.logger_factory = logger_factory
|
||||||
|
if name == "DATETIME_TAG":
|
||||||
|
name = datetime_tag()
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
def get_seeding_info_as_str(self) -> str:
|
||||||
|
"""Useful for creating unique experiment names based on seeds.
|
||||||
|
|
||||||
|
A typical example is to do `experiment.name = f"{experiment.name}_{experiment.get_seeding_info_as_str()}"`.
|
||||||
|
"""
|
||||||
|
return ",".join(
|
||||||
|
[
|
||||||
|
f"exp_seed={self.config.seed}",
|
||||||
|
f"train_seed={self.sampling_config.train_seed}",
|
||||||
|
f"test_seed={self.sampling_config.test_seed}",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_directory(cls, directory: str, restore_policy: bool = True) -> "Experiment":
|
def from_directory(cls, directory: str, restore_policy: bool = True) -> "Experiment":
|
||||||
@ -338,6 +356,16 @@ class ExperimentBuilder:
|
|||||||
self._optim_factory: OptimizerFactory | None = None
|
self._optim_factory: OptimizerFactory | None = None
|
||||||
self._policy_wrapper_factory: PolicyWrapperFactory | None = None
|
self._policy_wrapper_factory: PolicyWrapperFactory | None = None
|
||||||
self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks()
|
self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks()
|
||||||
|
self._experiment_name: str = ""
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def temp_config_mutation(self) -> Iterator[Self]:
|
||||||
|
"""Returns the builder instance where the configs can be modified without affecting the current instance."""
|
||||||
|
original_sampling_config = copy(self.sampling_config)
|
||||||
|
original_experiment_config = copy(self.experiment_config)
|
||||||
|
yield self
|
||||||
|
self.sampling_config = original_sampling_config
|
||||||
|
self.experiment_config = original_experiment_config
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def experiment_config(self) -> ExperimentConfig:
|
def experiment_config(self) -> ExperimentConfig:
|
||||||
@ -431,6 +459,20 @@ class ExperimentBuilder:
|
|||||||
self._trainer_callbacks.epoch_stop_callback = callback
|
self._trainer_callbacks.epoch_stop_callback = callback
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def with_experiment_name(
|
||||||
|
self,
|
||||||
|
experiment_name: str | Literal["DATETIME_TAG"] = "DATETIME_TAG",
|
||||||
|
) -> Self:
|
||||||
|
"""Sets the name of the experiment.
|
||||||
|
|
||||||
|
:param experiment_name: the name. If "DATETIME_TAG" (default) is given, the current date and time will be used.
|
||||||
|
:return: the builder
|
||||||
|
"""
|
||||||
|
if experiment_name == "DATETIME_TAG":
|
||||||
|
experiment_name = datetime_tag()
|
||||||
|
self._experiment_name = experiment_name
|
||||||
|
return self
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _create_agent_factory(self) -> AgentFactory:
|
def _create_agent_factory(self) -> AgentFactory:
|
||||||
pass
|
pass
|
||||||
@ -441,9 +483,12 @@ class ExperimentBuilder:
|
|||||||
else:
|
else:
|
||||||
return self._optim_factory
|
return self._optim_factory
|
||||||
|
|
||||||
def build(self) -> Experiment:
|
def build(self, add_seeding_info_to_name: bool = False) -> Experiment:
|
||||||
"""Creates the experiment based on the options specified via this builder.
|
"""Creates the experiment based on the options specified via this builder.
|
||||||
|
|
||||||
|
:param add_seeding_info_to_name: whether to add a postfix to the experiment name that contains
|
||||||
|
info about the training seeds. Useful for creating multiple experiments that only differ
|
||||||
|
by seeds.
|
||||||
:return: the experiment
|
:return: the experiment
|
||||||
"""
|
"""
|
||||||
agent_factory = self._create_agent_factory()
|
agent_factory = self._create_agent_factory()
|
||||||
@ -456,46 +501,25 @@ class ExperimentBuilder:
|
|||||||
agent_factory,
|
agent_factory,
|
||||||
self._sampling_config,
|
self._sampling_config,
|
||||||
self._logger_factory,
|
self._logger_factory,
|
||||||
|
name=self._experiment_name,
|
||||||
)
|
)
|
||||||
|
if add_seeding_info_to_name:
|
||||||
|
experiment.name = f"{experiment.name}_{experiment.get_seeding_info_as_str()}"
|
||||||
return experiment
|
return experiment
|
||||||
|
|
||||||
def build_default_seeded_experiments(self, num_experiments: int) -> dict[str, Experiment]:
|
def build_default_seeded_experiments(self, num_experiments: int) -> list[Experiment]:
|
||||||
"""Creates a dict of experiments with non-overlapping seeds, starting from the configured seed.
|
"""Creates a list of experiments with non-overlapping seeds, starting from the configured seed.
|
||||||
|
|
||||||
The keys of the dict are the experiment names, which are derived from the seeds used in the experiments.
|
Each experiment will have a unique name that is created from the original experiment name and the seeds used.
|
||||||
"""
|
"""
|
||||||
configured_experiment_config = copy(self.experiment_config)
|
num_train_envs = self.sampling_config.num_train_envs
|
||||||
configured_experiment_seed = configured_experiment_config.seed
|
|
||||||
configured_sampling_config = copy(self.sampling_config)
|
|
||||||
configured_train_seed = configured_sampling_config.train_seed
|
|
||||||
num_train_envs = configured_sampling_config.num_train_envs
|
|
||||||
|
|
||||||
seeded_experiments = {}
|
seeded_experiments = []
|
||||||
for i in range(num_experiments):
|
for i in range(num_experiments):
|
||||||
experiment_seed = configured_experiment_seed + i
|
with self.temp_config_mutation():
|
||||||
new_experiment_config_dict = asdict(configured_experiment_config)
|
self.experiment_config.seed += i
|
||||||
new_experiment_config_dict["seed"] = experiment_seed
|
self.sampling_config.train_seed += i * num_train_envs
|
||||||
|
seeded_experiments.append(self.build(add_seeding_info_to_name=True))
|
||||||
new_train_seed = i * num_train_envs + configured_train_seed
|
|
||||||
new_sampling_config_dict = asdict(configured_sampling_config)
|
|
||||||
new_sampling_config_dict["train_seed"] = new_train_seed
|
|
||||||
|
|
||||||
self.experiment_config = ExperimentConfig(**new_experiment_config_dict)
|
|
||||||
self.sampling_config = SamplingConfig(**new_sampling_config_dict)
|
|
||||||
exp = self.build()
|
|
||||||
|
|
||||||
experiment_name = ",".join(
|
|
||||||
[
|
|
||||||
f"exp_seed={exp.config.seed}",
|
|
||||||
f"train_seed={exp.sampling_config.train_seed}",
|
|
||||||
f"test_seed={exp.sampling_config.test_seed}",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
seeded_experiments[experiment_name] = exp
|
|
||||||
|
|
||||||
# restore original config
|
|
||||||
self.experiment_config = configured_experiment_config
|
|
||||||
self.sampling_config = configured_sampling_config
|
|
||||||
return seeded_experiments
|
return seeded_experiments
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user