Bugfix/parallel launcher for linux (#1141)

Closes: #1121
Michael Panchenko 2024-05-08 11:02:01 +02:00 committed by GitHub
commit f0b7abe015
9 changed files with 156 additions and 49 deletions

View File

@@ -3,6 +3,8 @@
 ## Release 1.1.0
 ### Api Extensions
+- `evaluation`: New package for repeating the same experiment with multiple seeds and aggregating the results. #1074 #1141
+  - The module `evaluation.launchers` for parallelization is currently in alpha state.
 - `data`:
   - `Batch`:
     - Add methods `to_dict` and `to_list_of_dicts`. #1063 #1098
@@ -24,7 +26,7 @@
 - `SamplingConfig`:
   - Add support for `batch_size=None`. #1077
   - Add `training_seed` for explicit seeding of training and test environments, the `test_seed` is inferred from `training_seed`. #1074
-- `highlevel.experiment`:
+- `experiment`:
   - `Experiment` now has a `name` attribute, which can be set using `ExperimentBuilder.with_name` and
     which determines the default run name and therefore the persistence subdirectory.
     It can still be overridden in `Experiment.run()`, the new parameter name being `run_name` rather than
@@ -34,8 +36,8 @@
   - Add method `build_seeded_collection` for the sound creation of multiple
     experiments with varying random seeds #1131
   - Add method `copy` to facilitate the creation of multiple experiments from a single builder #1131
-- `evaluation`: New package for repeating the same experiment with multiple seeds and aggregating the results. #1074
-  - The module `evaluation.launchers` for parallelization is currently in alpha state.
+- `env`:
+  - Added new `VectorEnvType` called `SUBPROC_SHARED_MEM_AUTO` and used for Atari and Mujoco venv creation. #1141
 - Loggers can now restore the logged data into python by using the new `restore_logged_data` method. #1074
 - `utils`:
   - `net.continuous.Critic`:
@@ -45,6 +47,8 @@
       to reuse the actor's preprocessing network #1128
   - `torch_utils` (new module)
     - Added context managers `torch_train_mode` and `policy_within_training_step` #1123
+  - `print`
+    - `DataclassPPrintMixin` now supports outputting a string, not just printing the pretty repr. #1141
 ### Fixes
 - `CriticFactoryReuseActor`: Enable the Critic flag `apply_preprocess_net_to_obs_only` for continuous critics,

View File

@@ -397,6 +397,7 @@ class AtariEnvFactory(EnvFactoryRegistered):
         frame_stack: int,
         scale: bool = False,
         use_envpool_if_available: bool = True,
+        venv_type: VectorEnvType = VectorEnvType.SUBPROC_SHARED_MEM_AUTO,
     ) -> None:
         assert "NoFrameskip" in task
         self.frame_stack = frame_stack
@@ -412,7 +413,7 @@ class AtariEnvFactory(EnvFactoryRegistered):
             task=task,
             train_seed=train_seed,
             test_seed=test_seed,
-            venv_type=VectorEnvType.SUBPROC_SHARED_MEM,
+            venv_type=venv_type,
             envpool_factory=envpool_factory,
         )

View File

@@ -76,7 +76,7 @@ class MujocoEnvFactory(EnvFactoryRegistered):
         train_seed: int,
         test_seed: int,
         obs_norm: bool = True,
-        venv_type: VectorEnvType = VectorEnvType.SUBPROC_SHARED_MEM,
+        venv_type: VectorEnvType = VectorEnvType.SUBPROC_SHARED_MEM_AUTO,
     ) -> None:
         super().__init__(
             task=task,
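
Both `AtariEnvFactory` and `MujocoEnvFactory` now default to the new `SUBPROC_SHARED_MEM_AUTO` venv type, so callers no longer need to branch on the platform themselves. A minimal usage sketch (task name and seed values are illustrative, not taken from this commit):

```python
from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.env import VectorEnvType

# New default: SUBPROC_SHARED_MEM_AUTO picks the multiprocessing context per platform.
env_factory = MujocoEnvFactory(
    task="Ant-v4",
    train_seed=42,
    test_seed=1337,
    obs_norm=True,
)

# The previous behaviour can still be requested explicitly when needed.
env_factory_fork = MujocoEnvFactory(
    task="Ant-v4",
    train_seed=42,
    test_seed=1337,
    obs_norm=True,
    venv_type=VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT,
)
```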

View File

@@ -6,14 +6,13 @@ For each repetition, a policy seed, train env seeds, and test env seeds are set
 are non-intersecting with the seeds of the other experiments.
 Each experiment's results are stored in a separate subdirectory.

-The final results are aggregated and turned into useful statistics with the rliable API.
+The final results are aggregated and turned into useful statistics with the rliable package.
 The call to `eval_experiments` will load the results from the log directory and
 create an interp-quantile mean plot for the returns as well as a performance profile plot.
 These plots are saved in the log directory and displayed in the console.
 """

 import os
-import sys

 import torch
@@ -21,7 +20,6 @@ from examples.mujoco.mujoco_env import MujocoEnvFactory
 from tianshou.evaluation.launcher import RegisteredExpLauncher
 from tianshou.evaluation.rliable_evaluation_hl import RLiableExperimentResult
 from tianshou.highlevel.config import SamplingConfig
-from tianshou.highlevel.env import VectorEnvType
 from tianshou.highlevel.experiment import (
     ExperimentConfig,
     PPOExperimentBuilder,
@@ -39,14 +37,14 @@ log = logging.getLogger(__name__)

 def main(
-    num_experiments: int = 2,
-    run_experiments_sequentially: bool = True,
+    num_experiments: int = 5,
+    run_experiments_sequentially: bool = False,
 ) -> RLiableExperimentResult:
-    """:param run_experiments_sequentially: if True, the experiments are run sequentially, otherwise in parallel.
-    LIMITATIONS: currently, the parallel execution does not seem to work properly on linux.
-    It might generally be undesired to run multiple experiments in parallel on the same machine,
-    as a single experiment already uses all available CPU cores by default.
-    :return: the directory where the results are stored
+    """:param num_experiments: the number of experiments to run. The experiments differ exclusively in the seeds.
+    :param run_experiments_sequentially: if True, the experiments are run sequentially, otherwise in parallel.
+        If a single experiment is set to use all available CPU cores,
+        it might be undesired to run multiple experiments in parallel on the same machine,
+    :return: an object containing rliable-based evaluation results
     """
     task = "Ant-v4"
     persistence_dir = os.path.abspath(os.path.join("log", task, "ppo", datetime_tag()))
@@ -57,12 +55,12 @@ def main(
         num_epochs=1,
         step_per_epoch=5000,
         batch_size=64,
-        num_train_envs=10,
-        num_test_envs=10,
-        num_test_episodes=10,
+        num_train_envs=5,
+        num_test_envs=5,
+        num_test_episodes=5,
         buffer_size=4096,
         step_per_collect=2048,
-        repeat_per_collect=10,
+        repeat_per_collect=1,
     )

     env_factory = MujocoEnvFactory(
@@ -70,9 +68,6 @@ def main(
         train_seed=sampling_config.train_seed,
         test_seed=sampling_config.test_seed,
         obs_norm=True,
-        venv_type=VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT
-        if sys.platform == "darwin"
-        else VectorEnvType.SUBPROC_SHARED_MEM,
     )

     hidden_sizes = (64, 64)
@@ -108,7 +103,18 @@ def main(
         launcher = RegisteredExpLauncher.sequential.create_launcher()
     else:
         launcher = RegisteredExpLauncher.joblib.create_launcher()
-    experiment_collection.run(launcher)
+    successful_experiment_stats = experiment_collection.run(launcher)
+    log.info(f"Successfully completed {len(successful_experiment_stats)} experiments.")
+
+    num_successful_experiments = len(successful_experiment_stats)
+    for i, info_stats in enumerate(successful_experiment_stats, start=1):
+        if info_stats is not None:
+            log.info(f"Training stats for successful experiment {i}/{num_successful_experiments}:")
+            log.info(info_stats.pprints_asdict())
+        else:
+            log.info(
+                f"No training stats available for successful experiment {i}/{num_successful_experiments}.",
+            )

     rliable_result = RLiableExperimentResult.load_from_disk(persistence_dir)
     rliable_result.eval_results(show_plots=True, save_plots=True)

View File

@@ -2,7 +2,7 @@
 import logging
 from abc import ABC, abstractmethod
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from copy import copy
 from dataclasses import asdict, dataclass
 from enum import Enum
@@ -10,6 +10,7 @@ from typing import Literal

 from joblib import Parallel, delayed

+from tianshou.data import InfoStats
 from tianshou.highlevel.experiment import Experiment

 log = logging.getLogger(__name__)
@@ -26,19 +27,89 @@ class JoblibConfig:


 class ExpLauncher(ABC):
+    def __init__(
+        self,
+        experiment_runner: Callable[
+            [Experiment],
+            InfoStats | None,
+        ] = lambda exp: exp.run().trainer_result,
+    ):
+        """:param experiment_runner: can be used to override the default way in which an experiment is executed.
+            Can be useful e.g., if one wants to use the high-level interfaces to setup an experiment (or an experiment
+            collection) and tinker with it prior to execution. This need often arises when prototyping with mechanisms
+            that are not yet supported by the high-level interfaces.
+            Passing this allows arbitrary things to happen during experiment execution, so use it with caution!
+        """
+        self.experiment_runner = experiment_runner
+
     @abstractmethod
-    def launch(self, experiments: Sequence[Experiment]) -> None:
-        pass
+    def _launch(self, experiments: Sequence[Experiment]) -> list[InfoStats | None]:
+        """Should call `self.experiment_runner` for each experiment in experiments and aggregate the results."""
+
+    def _safe_execute(self, exp: Experiment) -> InfoStats | None | Literal["failed"]:
+        try:
+            return self.experiment_runner(exp)
+        except BaseException as e:
+            log.error(f"Failed to run experiment {exp}.", exc_info=e)
+            return "failed"
+
+    @staticmethod
+    def _return_from_successful_and_failed_exps(
+        successful_exp_stats: list[InfoStats | None],
+        failed_exps: list[Experiment],
+    ) -> list[InfoStats | None]:
+        if not successful_exp_stats:
+            raise RuntimeError("All experiments failed, see error logs for more details.")
+        if failed_exps:
+            log.error(
+                f"Failed to run the following "
+                f"{len(failed_exps)}/{len(successful_exp_stats) + len(failed_exps)} experiments: {failed_exps}. "
+                f"See the logs for more details. "
+                f"Returning the results of {len(successful_exp_stats)} successful experiments.",
+            )
+        return successful_exp_stats
+
+    def launch(self, experiments: Sequence[Experiment]) -> list[InfoStats | None]:
+        """Will return the results of successfully executed experiments.
+
+        If a single experiment is passed, will not use parallelism and run it in the main process.
+        Failed experiments will be logged, and a RuntimeError is only raised if all experiments have failed.
+        """
+        if len(experiments) == 1:
+            log.info(
+                "A single experiment is being run, will not use parallelism and run it in the main process.",
+            )
+            return [self.experiment_runner(experiments[0])]
+        return self._launch(experiments)


 class SequentialExpLauncher(ExpLauncher):
-    def launch(self, experiments: Sequence[Experiment]) -> None:
+    """Convenience wrapper around a simple for loop to run experiments sequentially."""
+
+    def _launch(self, experiments: Sequence[Experiment]) -> list[InfoStats | None]:
+        successful_exp_stats = []
+        failed_exps = []
         for exp in experiments:
-            exp.run()
+            exp_stats = self._safe_execute(exp)
+            if exp_stats == "failed":
+                failed_exps.append(exp)
+            else:
+                successful_exp_stats.append(exp_stats)
+        # noinspection PyTypeChecker
+        return self._return_from_successful_and_failed_exps(successful_exp_stats, failed_exps)


 class JoblibExpLauncher(ExpLauncher):
-    def __init__(self, joblib_cfg: JoblibConfig | None = None) -> None:
+    def __init__(
+        self,
+        joblib_cfg: JoblibConfig | None = None,
+        experiment_runner: Callable[
+            [Experiment],
+            InfoStats | None,
+        ] = lambda exp: exp.run().trainer_result,
+    ) -> None:
+        super().__init__(experiment_runner=experiment_runner)
         self.joblib_cfg = copy(joblib_cfg) if joblib_cfg is not None else JoblibConfig()
         # Joblib's backend is hard-coded to loky since the threading backend produces different results
         if self.joblib_cfg.backend != "loky":
@@ -48,15 +119,18 @@ class JoblibExpLauncher(ExpLauncher):
             )
             self.joblib_cfg.backend = "loky"

-    def launch(self, experiments: Sequence[Experiment]) -> None:
-        Parallel(**asdict(self.joblib_cfg))(delayed(self._safe_execute)(exp) for exp in experiments)
-
-    @staticmethod
-    def _safe_execute(exp: Experiment) -> None:
-        try:
-            exp.run()
-        except BaseException as e:
-            log.error(e)
+    def _launch(self, experiments: Sequence[Experiment]) -> list[InfoStats | None]:
+        results = Parallel(**asdict(self.joblib_cfg))(
+            delayed(self._safe_execute)(exp) for exp in experiments
+        )
+        successful_exps = []
+        failed_exps = []
+        for exp, result in zip(experiments, results, strict=True):
+            if result == "failed":
+                failed_exps.append(exp)
+            else:
+                successful_exps.append(result)
+        return self._return_from_successful_and_failed_exps(successful_exps, failed_exps)


 class RegisteredExpLauncher(Enum):
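
The launchers now return the `InfoStats` of the experiments that completed successfully and accept a custom `experiment_runner`. A sketch of how this might be driven directly; `experiments` stands for a list of already-built `Experiment` objects, `verbose_runner` is a hypothetical runner, and the `n_jobs` field of `JoblibConfig` is assumed from joblib's conventions rather than visible in this diff:

```python
from tianshou.data import InfoStats
from tianshou.evaluation.launcher import JoblibConfig, JoblibExpLauncher, RegisteredExpLauncher
from tianshou.highlevel.experiment import Experiment

experiments: list[Experiment] = ...  # built elsewhere, e.g. via ExperimentBuilder.build_seeded_collection

# Default usage: obtain a launcher from the registry enum, as the example script does.
launcher = RegisteredExpLauncher.joblib.create_launcher()
successful_stats = launcher.launch(experiments=experiments)


def verbose_runner(exp: Experiment) -> InfoStats | None:
    # Anything mapping an Experiment to InfoStats | None is allowed here,
    # e.g. inspecting or tweaking the experiment before executing it.
    print(f"Running experiment: {exp.name}")
    return exp.run().trainer_result


# Custom runner plus an explicit joblib configuration.
custom_launcher = JoblibExpLauncher(
    joblib_cfg=JoblibConfig(n_jobs=2),
    experiment_runner=verbose_runner,
)
successful_stats = custom_launcher.launch(experiments=experiments)
```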

View File

@@ -85,7 +85,7 @@ class SamplingConfig(ToStringMixin):
    """
    controls, within one gradient update step of an on-policy algorithm, the number of times an
    actual gradient update is applied using the full collected dataset, i.e. if the parameter is
-   `n`, then the collected data shall be used five times to update the policy within the same
+   5, then the collected data shall be used five times to update the policy within the same
    training step.

    The parameter is ignored and may be set to None for off-policy and offline algorithms.
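
This docstring presumably belongs to `repeat_per_collect` (the field name is not visible in the hunk, but the wording matches the parameter used in the example script above). A sketch with illustrative values showing how the parameter is meant to be read for an on-policy algorithm:

```python
from tianshou.highlevel.config import SamplingConfig

# For an on-policy algorithm such as PPO: each batch of `step_per_collect`
# collected transitions is reused for 5 gradient-update passes within the
# same training step (all values here are illustrative).
sampling_config = SamplingConfig(
    num_epochs=1,
    step_per_epoch=5000,
    step_per_collect=2048,
    batch_size=64,
    repeat_per_collect=5,
)
```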

View File

@@ -1,4 +1,5 @@
 import logging
+import platform
 from abc import ABC, abstractmethod
 from collections.abc import Callable, Sequence
 from enum import Enum
@@ -66,13 +67,15 @@ class VectorEnvType(Enum):
     """Vectorized environment without parallelization; environments are processed sequentially"""
     SUBPROC = "subproc"
     """Parallelization based on `subprocess`"""
-    SUBPROC_SHARED_MEM = "shmem"
+    SUBPROC_SHARED_MEM_DEFAULT_CONTEXT = "shmem"
     """Parallelization based on `subprocess` with shared memory"""
     SUBPROC_SHARED_MEM_FORK_CONTEXT = "shmem_fork"
     """Parallelization based on `subprocess` with shared memory and fork context (relevant for macOS, which uses `spawn`
     by default https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods)"""
     RAY = "ray"
     """Parallelization based on the `ray` library"""
+    SUBPROC_SHARED_MEM_AUTO = "subproc_shared_mem_auto"
+    """Parallelization based on `subprocess` with shared memory, using default context on windows and fork context otherwise"""

     def create_venv(
         self,
@@ -83,10 +86,16 @@ class VectorEnvType(Enum):
                 return DummyVectorEnv(factories)
             case VectorEnvType.SUBPROC:
                 return SubprocVectorEnv(factories)
-            case VectorEnvType.SUBPROC_SHARED_MEM:
+            case VectorEnvType.SUBPROC_SHARED_MEM_DEFAULT_CONTEXT:
                 return SubprocVectorEnv(factories, share_memory=True)
             case VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT:
                 return SubprocVectorEnv(factories, share_memory=True, context="fork")
+            case VectorEnvType.SUBPROC_SHARED_MEM_AUTO:
+                if platform.system().lower() == "windows":
+                    selected_venv_type = VectorEnvType.SUBPROC_SHARED_MEM_DEFAULT_CONTEXT
+                else:
+                    selected_venv_type = VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT
+                return selected_venv_type.create_venv(factories)
             case VectorEnvType.RAY:
                 return RayVectorEnv(factories)
             case _:
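
With the new enum member, the multiprocessing context is chosen at venv-creation time. A small sketch of what `SUBPROC_SHARED_MEM_AUTO` resolves to, assuming `create_venv` only needs the list of environment factories (the CartPole factories are illustrative stand-ins):

```python
import gymnasium as gym

from tianshou.highlevel.env import VectorEnvType

# Four identical factories, purely for illustration.
factories = [lambda: gym.make("CartPole-v1") for _ in range(4)]

# On Windows this dispatches to SUBPROC_SHARED_MEM_DEFAULT_CONTEXT (the platform's
# default start method); on all other platforms it uses SUBPROC_SHARED_MEM_FORK_CONTEXT.
venv = VectorEnvType.SUBPROC_SHARED_MEM_AUTO.create_venv(factories)
```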

View File

@@ -87,7 +87,6 @@ from tianshou.utils.string import ToStringMixin
 if TYPE_CHECKING:
     from tianshou.evaluation.launcher import ExpLauncher, RegisteredExpLauncher

-
 log = logging.getLogger(__name__)
@@ -356,15 +355,20 @@ class Experiment(ToStringMixin):
 class ExperimentCollection:
+    """Shallow wrapper around a list of experiments providing a simple interface for running them with a launcher."""
+
     def __init__(self, experiments: list[Experiment]):
         self.experiments = experiments

-    def run(self, launcher: Union["ExpLauncher", "RegisteredExpLauncher"]) -> None:
+    def run(
+        self,
+        launcher: Union["ExpLauncher", "RegisteredExpLauncher"],
+    ) -> list[InfoStats | None]:
         from tianshou.evaluation.launcher import RegisteredExpLauncher

         if isinstance(launcher, RegisteredExpLauncher):
             launcher = launcher.create_launcher()
-        launcher.launch(experiments=self.experiments)
+        return launcher.launch(experiments=self.experiments)


 class ExperimentBuilder:
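
Since `ExperimentCollection.run` now forwards the launcher's return value, the seeded-collection workflow from the changelog can surface per-seed training statistics directly. A sketch assuming `experiment_builder` is an already configured builder (e.g. a `PPOExperimentBuilder`); the `num_experiments` parameter name of `build_seeded_collection` is assumed, not taken from this diff:

```python
from tianshou.evaluation.launcher import RegisteredExpLauncher

# Build several experiments differing only in their seeds, run them, and inspect the stats.
experiment_collection = experiment_builder.build_seeded_collection(num_experiments=3)
successful_stats = experiment_collection.run(RegisteredExpLauncher.joblib)  # list[InfoStats | None]

for info_stats in successful_stats:
    if info_stats is not None:
        print(info_stats.pprints_asdict())
```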

View File

@ -1,20 +1,29 @@
import pprint
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from pprint import pprint
@dataclass @dataclass
class DataclassPPrintMixin: class DataclassPPrintMixin:
def pprint_asdict(self, exclude_fields: Sequence[str] | None = None) -> None: def pprint_asdict(self, exclude_fields: Sequence[str] | None = None, indent: int = 4) -> None:
"""Pretty-print the object as a dict, excluding specified fields. """Pretty-print the object as a dict, excluding specified fields.
:param exclude_fields: A sequence of field names to exclude from the output. :param exclude_fields: A sequence of field names to exclude from the output.
If None, no fields are excluded. If None, no fields are excluded.
:param indent: The indentation to use when pretty-printing.
""" """
print(f"{self.__class__.__name__}") print(self.pprints_asdict(exclude_fields=exclude_fields, indent=indent))
print("----------------------------------------")
def pprints_asdict(self, exclude_fields: Sequence[str] | None = None, indent: int = 4) -> str:
"""String corresponding to pretty-print of the object as a dict, excluding specified fields.
:param exclude_fields: A sequence of field names to exclude from the output.
If None, no fields are excluded.
:param indent: The indentation to use when pretty-printing.
"""
prefix = f"{self.__class__.__name__}\n----------------------------------------\n"
print_dict = asdict(self) print_dict = asdict(self)
exclude_fields = exclude_fields or [] exclude_fields = exclude_fields or []
for field in exclude_fields: for field in exclude_fields:
print_dict.pop(field, None) print_dict.pop(field, None)
pprint(print_dict) return prefix + pprint.pformat(print_dict, indent=indent)
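
A short sketch of the new string-returning variant next to the existing printing method, using a hypothetical dataclass and assuming the mixin is importable from `tianshou.utils.print`:

```python
import logging
from dataclasses import dataclass

from tianshou.utils.print import DataclassPPrintMixin


@dataclass
class TrainingSummary(DataclassPPrintMixin):  # hypothetical dataclass, for illustration only
    best_reward: float
    num_epochs: int
    notes: str = ""


summary = TrainingSummary(best_reward=195.3, num_epochs=10)

summary.pprint_asdict(exclude_fields=["notes"])  # prints to stdout, as before
text = summary.pprints_asdict(exclude_fields=["notes"], indent=2)  # new: returns the string
logging.getLogger(__name__).info(text)  # e.g. route it through a logger, as the example script does
```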