Improve creation of multiple seeded experiments:

* Add class ExperimentCollection to improve usability
  * Remove parameters from ExperimentBuilder.build
  * Renamed ExperimentBuilder.build_default_seeded_experiments to build_seeded_collection,
    changing the return type to ExperimentCollection
  * Replace temp_config_mutation (which was not appropriate for the public API) with
    method copy (which performs a safe deep copy)
This commit is contained in:
Dominik Jain 2024-04-30 17:22:11 +02:00 committed by Michael Panchenko
parent 35779696ee
commit 024b80e79c
3 changed files with 83 additions and 171 deletions

View File

@ -14,8 +14,6 @@ These plots are saved in the log directory and displayed in the console.
import os import os
import sys import sys
from collections.abc import Sequence
from typing import Literal
import torch import torch
@ -41,86 +39,30 @@ log = logging.getLogger(__name__)
def main( def main(
experiment_config: ExperimentConfig, num_experiments: int = 2,
task: str = "Ant-v4",
num_experiments: int = 5,
buffer_size: int = 4096,
hidden_sizes: Sequence[int] = (64, 64),
lr: float = 3e-4,
gamma: float = 0.99,
epoch: int = 3,
step_per_epoch: int = 30000,
step_per_collect: int = 2048,
repeat_per_collect: int = 10,
batch_size: int = 64,
training_num: int = 10,
test_num: int = 10,
rew_norm: bool = True,
vf_coef: float = 0.25,
ent_coef: float = 0.0,
gae_lambda: float = 0.95,
bound_action_method: Literal["clip", "tanh"] | None = "clip",
lr_decay: bool = True,
max_grad_norm: float = 0.5,
eps_clip: float = 0.2,
dual_clip: float | None = None,
value_clip: bool = False,
norm_adv: bool = False,
recompute_adv: bool = True,
run_experiments_sequentially: bool = True, run_experiments_sequentially: bool = True,
) -> str: ) -> RLiableExperimentResult:
"""Use the high-level API of TianShou to evaluate the PPO algorithm on a MuJoCo environment with multiple seeds for """:param run_experiments_sequentially: if True, the experiments are run sequentially, otherwise in parallel.
a given configuration. The results for each run are stored in separate sub-folders. After the agents are trained,
the results are evaluated using the rliable API.
:param experiment_config:
:param task: a mujoco task name
:param num_experiments: how many experiments to run with different seeds
:param buffer_size:
:param hidden_sizes:
:param lr:
:param gamma:
:param epoch:
:param step_per_epoch:
:param step_per_collect:
:param repeat_per_collect:
:param batch_size:
:param training_num:
:param test_num:
:param rew_norm:
:param vf_coef:
:param ent_coef:
:param gae_lambda:
:param bound_action_method:
:param lr_decay:
:param max_grad_norm:
:param eps_clip:
:param dual_clip:
:param value_clip:
:param norm_adv:
:param recompute_adv:
:param run_experiments_sequentially: if True, the experiments are run sequentially, otherwise in parallel.
LIMITATIONS: currently, the parallel execution does not seem to work properly on linux. LIMITATIONS: currently, the parallel execution does not seem to work properly on linux.
It might generally be undesired to run multiple experiments in parallel on the same machine, It might generally be undesired to run multiple experiments in parallel on the same machine,
as a single experiment already uses all available CPU cores by default. as a single experiment already uses all available CPU cores by default.
:return: the directory where the results are stored :return: the directory where the results are stored
""" """
task = "Ant-v4"
persistence_dir = os.path.abspath(os.path.join("log", task, "ppo", datetime_tag())) persistence_dir = os.path.abspath(os.path.join("log", task, "ppo", datetime_tag()))
experiment_config.persistence_base_dir = persistence_dir experiment_config = ExperimentConfig(persistence_base_dir=persistence_dir, watch=False)
log.info(f"Will save all experiment results to {persistence_dir}.")
experiment_config.watch = False
sampling_config = SamplingConfig( sampling_config = SamplingConfig(
num_epochs=epoch, num_epochs=1,
step_per_epoch=step_per_epoch, step_per_epoch=5000,
batch_size=batch_size, batch_size=64,
num_train_envs=training_num, num_train_envs=10,
num_test_envs=test_num, num_test_envs=10,
num_test_episodes=test_num, num_test_episodes=10,
buffer_size=buffer_size, buffer_size=4096,
step_per_collect=step_per_collect, step_per_collect=2048,
repeat_per_collect=repeat_per_collect, repeat_per_collect=10,
) )
env_factory = MujocoEnvFactory( env_factory = MujocoEnvFactory(
@ -133,52 +75,45 @@ def main(
else VectorEnvType.SUBPROC_SHARED_MEM, else VectorEnvType.SUBPROC_SHARED_MEM,
) )
experiments = ( hidden_sizes = (64, 64)
experiment_collection = (
PPOExperimentBuilder(env_factory, experiment_config, sampling_config) PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
.with_ppo_params( .with_ppo_params(
PPOParams( PPOParams(
discount_factor=gamma, discount_factor=0.99,
gae_lambda=gae_lambda, gae_lambda=0.95,
action_bound_method=bound_action_method, action_bound_method="clip",
reward_normalization=rew_norm, reward_normalization=True,
ent_coef=ent_coef, ent_coef=0.0,
vf_coef=vf_coef, vf_coef=0.25,
max_grad_norm=max_grad_norm, max_grad_norm=0.5,
value_clip=value_clip, value_clip=False,
advantage_normalization=norm_adv, advantage_normalization=False,
eps_clip=eps_clip, eps_clip=0.2,
dual_clip=dual_clip, dual_clip=None,
recompute_advantage=recompute_adv, recompute_advantage=True,
lr=lr, lr=3e-4,
lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config) lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config),
if lr_decay
else None,
dist_fn=DistributionFunctionFactoryIndependentGaussians(), dist_fn=DistributionFunctionFactoryIndependentGaussians(),
), ),
) )
.with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True) .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh) .with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
.with_logger_factory(LoggerFactoryDefault("tensorboard")) .with_logger_factory(LoggerFactoryDefault("tensorboard"))
.build_default_seeded_experiments(num_experiments) .build_seeded_collection(num_experiments)
) )
if run_experiments_sequentially: if run_experiments_sequentially:
launcher = RegisteredExpLauncher.sequential.create_launcher() launcher = RegisteredExpLauncher.sequential.create_launcher()
else: else:
launcher = RegisteredExpLauncher.joblib.create_launcher() launcher = RegisteredExpLauncher.joblib.create_launcher()
launcher.launch(experiments) experiment_collection.run(launcher)
return persistence_dir rliable_result = RLiableExperimentResult.load_from_disk(persistence_dir)
def eval_experiments(log_dir: str) -> RLiableExperimentResult:
"""Evaluate the experiments in the given log directory using the rliable API."""
rliable_result = RLiableExperimentResult.load_from_disk(log_dir)
rliable_result.eval_results(show_plots=True, save_plots=True) rliable_result.eval_results(show_plots=True, save_plots=True)
return rliable_result return rliable_result
if __name__ == "__main__": if __name__ == "__main__":
log_dir = logging.run_cli(main, level=logging.INFO) result = logging.run_cli(main, level=logging.INFO)
assert isinstance(log_dir, str) # for mypy
evaluation_result = eval_experiments(log_dir)

View File

@ -79,30 +79,3 @@ def test_experiment_builder_discrete_default_params(builder_cls: type[Experiment
experiment = builder.build() experiment = builder.build()
experiment.run(run_name="test") experiment.run(run_name="test")
print(experiment) print(experiment)
def test_temp_builder_modification() -> None:
env_factory = DiscreteTestEnvFactory()
sampling_config = SamplingConfig(
num_epochs=1,
step_per_epoch=100,
num_train_envs=2,
num_test_envs=2,
)
builder = PPOExperimentBuilder(
experiment_config=ExperimentConfig(persistence_enabled=False),
env_factory=env_factory,
sampling_config=sampling_config,
)
original_seed = builder.experiment_config.seed
original_train_seed = builder.sampling_config.train_seed
with builder.temp_config_mutation():
builder.experiment_config.seed += 12345
builder.sampling_config.train_seed += 456
exp = builder.build()
assert builder.experiment_config.seed == original_seed
assert builder.sampling_config.train_seed == original_train_seed
assert exp.config.seed == original_seed + 12345
assert exp.sampling_config.train_seed == original_train_seed + 456

View File

@ -1,12 +1,11 @@
import os import os
import pickle import pickle
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Iterator, Sequence from collections.abc import Sequence
from contextlib import contextmanager from copy import deepcopy
from copy import copy
from dataclasses import dataclass from dataclasses import dataclass
from pprint import pformat from pprint import pformat
from typing import Self, Dict, Any from typing import TYPE_CHECKING, Any, Self, Union, cast
import numpy as np import numpy as np
import torch import torch
@ -85,6 +84,10 @@ from tianshou.utils.logging import datetime_tag
from tianshou.utils.net.common import ModuleType from tianshou.utils.net.common import ModuleType
from tianshou.utils.string import ToStringMixin from tianshou.utils.string import ToStringMixin
if TYPE_CHECKING:
from tianshou.evaluation.launcher import ExpLauncher, RegisteredExpLauncher
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -157,19 +160,6 @@ class Experiment(ToStringMixin):
self.logger_factory = logger_factory self.logger_factory = logger_factory
self.name = name self.name = name
def get_seeding_info_as_str(self) -> str:
"""Useful for creating unique experiment names based on seeds.
A typical example is to do `experiment.name = f"{experiment.name}_{experiment.get_seeding_info_as_str()}"`.
"""
return "_".join(
[
f"exp_seed={self.config.seed}",
f"train_seed={self.sampling_config.train_seed}",
f"test_seed={self.sampling_config.test_seed}",
],
)
@classmethod @classmethod
def from_directory(cls, directory: str, restore_policy: bool = True) -> "Experiment": def from_directory(cls, directory: str, restore_policy: bool = True) -> "Experiment":
"""Restores an experiment from a previously stored pickle. """Restores an experiment from a previously stored pickle.
@ -184,6 +174,20 @@ class Experiment(ToStringMixin):
experiment.config.policy_restore_directory = directory experiment.config.policy_restore_directory = directory
return experiment return experiment
def get_seeding_info_as_str(self) -> str:
"""Returns information on the seeds used in the experiment as a string.
This can be useful for creating unique experiment names based on seeds, e.g.
A typical example is to do `experiment.name = f"{experiment.name}_{experiment.get_seeding_info_as_str()}"`.
"""
return "_".join(
[
f"exp_seed={self.config.seed}",
f"train_seed={self.sampling_config.train_seed}",
f"test_seed={self.sampling_config.test_seed}",
],
)
def _set_seed(self) -> None: def _set_seed(self) -> None:
seed = self.config.seed seed = self.config.seed
log.info(f"Setting random seed {seed}") log.info(f"Setting random seed {seed}")
@ -206,7 +210,7 @@ class Experiment(ToStringMixin):
run_name: str | None = None, run_name: str | None = None,
logger_run_id: str | None = None, logger_run_id: str | None = None,
raise_error_on_dirname_collision: bool = True, raise_error_on_dirname_collision: bool = True,
**kwargs: Dict[str, Any], **kwargs: dict[str, Any],
) -> ExperimentResult: ) -> ExperimentResult:
"""Run the experiment and return the results. """Run the experiment and return the results.
@ -225,7 +229,7 @@ class Experiment(ToStringMixin):
# backward compatibility # backward compatibility
_experiment_name = kwargs.pop("experiment_name", None) _experiment_name = kwargs.pop("experiment_name", None)
if _experiment_name is not None: if _experiment_name is not None:
run_name = _experiment_name run_name = cast(str, _experiment_name)
deprecation( deprecation(
"Parameter run_name should now be used instead of experiment_name. " "Parameter run_name should now be used instead of experiment_name. "
"Support for experiment_name will be removed in the future.", "Support for experiment_name will be removed in the future.",
@ -351,6 +355,18 @@ class Experiment(ToStringMixin):
) )
class ExperimentCollection:
def __init__(self, experiments: list[Experiment]):
self.experiments = experiments
def run(self, launcher: Union["ExpLauncher", "RegisteredExpLauncher"]) -> None:
from tianshou.evaluation.launcher import RegisteredExpLauncher
if isinstance(launcher, RegisteredExpLauncher):
launcher = launcher.create_launcher()
launcher.launch(experiments=self.experiments)
class ExperimentBuilder: class ExperimentBuilder:
def __init__( def __init__(
self, self,
@ -371,14 +387,8 @@ class ExperimentBuilder:
self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks() self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks()
self._name: str = self.__class__.__name__.replace("Builder", "") + "_" + datetime_tag() self._name: str = self.__class__.__name__.replace("Builder", "") + "_" + datetime_tag()
@contextmanager def copy(self) -> Self:
def temp_config_mutation(self) -> Iterator[Self]: return deepcopy(self)
"""Returns the builder instance where the configs can be modified without affecting the current instance."""
original_sampling_config = copy(self.sampling_config)
original_experiment_config = copy(self.experiment_config)
yield self
self.sampling_config = original_sampling_config
self.experiment_config = original_experiment_config
@property @property
def experiment_config(self) -> ExperimentConfig: def experiment_config(self) -> ExperimentConfig:
@ -495,12 +505,9 @@ class ExperimentBuilder:
else: else:
return self._optim_factory return self._optim_factory
def build(self, add_seeding_info_to_name: bool = False) -> Experiment: def build(self) -> Experiment:
"""Creates the experiment based on the options specified via this builder. """Creates the experiment based on the options specified via this builder.
:param add_seeding_info_to_name: whether to add a postfix to the experiment name that contains
info about the training seeds. Useful for creating multiple experiments that only differ
by seeds.
:return: the experiment :return: the experiment
""" """
agent_factory = self._create_agent_factory() agent_factory = self._create_agent_factory()
@ -515,27 +522,24 @@ class ExperimentBuilder:
name=self._name, name=self._name,
logger_factory=self._logger_factory, logger_factory=self._logger_factory,
) )
if add_seeding_info_to_name:
if not experiment.name:
experiment.name = experiment.get_seeding_info_as_str()
else:
experiment.name = f"{experiment.name}_{experiment.get_seeding_info_as_str()}"
return experiment return experiment
def build_default_seeded_experiments(self, num_experiments: int) -> list[Experiment]: def build_seeded_collection(self, num_experiments: int) -> ExperimentCollection:
"""Creates a list of experiments with non-overlapping seeds, starting from the configured seed. """Creates a collection of experiments with non-overlapping random seeds, starting from the configured seed.
Each experiment will have a unique name that is created from the original experiment name and the seeds used. Each experiment in the collection will have a unique name that is created from the original experiment name and the seeds used.
""" """
num_train_envs = self.sampling_config.num_train_envs num_train_envs = self.sampling_config.num_train_envs
seeded_experiments = [] seeded_experiments = []
for i in range(num_experiments): for i in range(num_experiments):
with self.temp_config_mutation(): builder = self.copy()
self.experiment_config.seed += i builder.experiment_config.seed += i
self.sampling_config.train_seed += i * num_train_envs builder.sampling_config.train_seed += i * num_train_envs
seeded_experiments.append(self.build(add_seeding_info_to_name=True)) experiment = builder.build()
return seeded_experiments experiment.name += f"_{experiment.get_seeding_info_as_str()}"
seeded_experiments.append(experiment)
return ExperimentCollection(seeded_experiments)
class _BuilderMixinActorFactory(ActorFutureProviderProtocol): class _BuilderMixinActorFactory(ActorFutureProviderProtocol):