Allow explicit setting of multiprocessing context for SubprocEnvWorker (#1072)

Running multiple training runs in parallel (with, for example, joblib)
fails on macOS due to a change in the standard context for
multiprocessing (see
[here](https://stackoverflow.com/questions/65098398/why-using-fork-works-but-using-spawn-fails-in-python3-8-multiprocessing)
or
[here](https://www.reddit.com/r/learnpython/comments/g5372v/multiprocessing_with_fork_on_macos/)).
This PR adds the ability to explicitly set a multiprocessing context for
the SubProcEnvWorker (similar to gymnasium's
[AsyncVecEnv](https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/vector/async_vector_env.py)).
---------

Co-authored-by: Maximilian Huettenrauch <m.huettenrauch@appliedai.de>
Co-authored-by: Michael Panchenko <35432522+MischaPanch@users.noreply.github.com>
This commit is contained in:
maxhuettenrauch 2024-03-14 11:07:56 +01:00 committed by GitHub
parent 1714c7f2c7
commit e82379c47f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 123 additions and 34 deletions

View File

@ -253,4 +253,7 @@ Dominik
Tsinghua Tsinghua
Tianshou Tianshou
appliedAI appliedAI
macOS
joblib
master
Panchenko Panchenko

View File

@ -62,11 +62,17 @@ class MujocoEnvObsRmsPersistence(Persistence):
class MujocoEnvFactory(EnvFactoryRegistered): class MujocoEnvFactory(EnvFactoryRegistered):
def __init__(self, task: str, seed: int, obs_norm=True) -> None: def __init__(
self,
task: str,
seed: int,
obs_norm: bool = True,
venv_type: VectorEnvType = VectorEnvType.SUBPROC_SHARED_MEM,
) -> None:
super().__init__( super().__init__(
task=task, task=task,
seed=seed, seed=seed,
venv_type=VectorEnvType.SUBPROC_SHARED_MEM, venv_type=venv_type,
envpool_factory=EnvPoolFactory() if envpool_is_available else None, envpool_factory=EnvPoolFactory() if envpool_is_available else None,
) )
self.obs_norm = obs_norm self.obs_norm = obs_norm

58
tianshou/env/venvs.py vendored
View File

@ -1,5 +1,5 @@
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from typing import Any from typing import Any, Literal
import gymnasium as gym import gymnasium as gym
import numpy as np import numpy as np
@ -371,8 +371,13 @@ class DummyVectorEnv(BaseVectorEnv):
Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage. Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
""" """
def __init__(self, env_fns: Sequence[Callable[[], ENV_TYPE]], **kwargs: Any) -> None: def __init__(
super().__init__(env_fns, DummyEnvWorker, **kwargs) self,
env_fns: Sequence[Callable[[], ENV_TYPE]],
wait_num: int | None = None,
timeout: float | None = None,
) -> None:
super().__init__(env_fns, DummyEnvWorker, wait_num, timeout)
class SubprocVectorEnv(BaseVectorEnv): class SubprocVectorEnv(BaseVectorEnv):
@ -381,13 +386,36 @@ class SubprocVectorEnv(BaseVectorEnv):
.. seealso:: .. seealso::
Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage. Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
Additional arguments are:
:param share_memory: whether to share memory between the main process and the worker process. Allows for
shared buffers to exchange observations
:param context: the context to use for multiprocessing. Usually it's fine to use the default context, but
`spawn` as well as `fork` can have non-obvious side effects, see for example
https://github.com/google-deepmind/mujoco/issues/742, or
https://github.com/Farama-Foundation/Gymnasium/issues/222.
Consider using 'fork' when using macOS and additional parallelization, for example via joblib.
Defaults to None, which will use the default system context.
""" """
def __init__(self, env_fns: Sequence[Callable[[], ENV_TYPE]], **kwargs: Any) -> None: def __init__(
self,
env_fns: Sequence[Callable[[], ENV_TYPE]],
wait_num: int | None = None,
timeout: float | None = None,
share_memory: bool = False,
context: Literal["fork", "spawn"] | None = None,
) -> None:
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker: def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
return SubprocEnvWorker(fn, share_memory=False) return SubprocEnvWorker(fn, share_memory=share_memory, context=context)
super().__init__(env_fns, worker_fn, **kwargs) super().__init__(
env_fns,
worker_fn,
wait_num,
timeout,
)
class ShmemVectorEnv(BaseVectorEnv): class ShmemVectorEnv(BaseVectorEnv):
@ -400,11 +428,16 @@ class ShmemVectorEnv(BaseVectorEnv):
Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage. Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
""" """
def __init__(self, env_fns: Sequence[Callable[[], ENV_TYPE]], **kwargs: Any) -> None: def __init__(
self,
env_fns: Sequence[Callable[[], ENV_TYPE]],
wait_num: int | None = None,
timeout: float | None = None,
) -> None:
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker: def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
return SubprocEnvWorker(fn, share_memory=True) return SubprocEnvWorker(fn, share_memory=True)
super().__init__(env_fns, worker_fn, **kwargs) super().__init__(env_fns, worker_fn, wait_num, timeout)
class RayVectorEnv(BaseVectorEnv): class RayVectorEnv(BaseVectorEnv):
@ -417,7 +450,12 @@ class RayVectorEnv(BaseVectorEnv):
Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage. Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
""" """
def __init__(self, env_fns: Sequence[Callable[[], ENV_TYPE]], **kwargs: Any) -> None: def __init__(
self,
env_fns: Sequence[Callable[[], ENV_TYPE]],
wait_num: int | None = None,
timeout: float | None = None,
) -> None:
try: try:
import ray import ray
except ImportError as exception: except ImportError as exception:
@ -426,4 +464,4 @@ class RayVectorEnv(BaseVectorEnv):
) from exception ) from exception
if not ray.is_initialized(): if not ray.is_initialized():
ray.init() ray.init()
super().__init__(env_fns, lambda env_fn: RayEnvWorker(env_fn), **kwargs) super().__init__(env_fns, lambda env_fn: RayEnvWorker(env_fn), wait_num, timeout)

View File

@ -1,10 +1,11 @@
import ctypes import ctypes
import multiprocessing
import time import time
from collections import OrderedDict from collections import OrderedDict
from collections.abc import Callable from collections.abc import Callable
from multiprocessing import Array, Pipe, connection from multiprocessing import Pipe, connection
from multiprocessing.context import Process from multiprocessing.context import BaseContext
from typing import Any from typing import Any, Literal
import gymnasium as gym import gymnasium as gym
import numpy as np import numpy as np
@ -31,10 +32,26 @@ _NP_TO_CT = {
class ShArray: class ShArray:
"""Wrapper of multiprocessing Array.""" """Wrapper of multiprocessing Array.
def __init__(self, dtype: np.generic, shape: tuple[int]) -> None: Example usage:
self.arr = Array(_NP_TO_CT[dtype.type], int(np.prod(shape))) # type: ignore
::
import numpy as np
import multiprocessing as mp
from tianshou.env.worker.subproc import ShArray
ctx = mp.get_context('fork') # set an explicit context
arr = ShArray(np.dtype(np.float32), (2, 3), ctx)
arr.save(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32))
print(arr.get())
"""
def __init__(self, dtype: np.generic, shape: tuple[int], ctx: BaseContext | None) -> None:
if ctx is None:
ctx = multiprocessing.get_context()
self.arr = ctx.Array(_NP_TO_CT[dtype.type], int(np.prod(shape))) # type: ignore
self.dtype = dtype self.dtype = dtype
self.shape = shape self.shape = shape
@ -49,14 +66,14 @@ class ShArray:
return np.frombuffer(obj, dtype=self.dtype).reshape(self.shape) # type: ignore return np.frombuffer(obj, dtype=self.dtype).reshape(self.shape) # type: ignore
def _setup_buf(space: gym.Space) -> dict | tuple | ShArray: def _setup_buf(space: gym.Space, ctx: BaseContext) -> dict | tuple | ShArray:
if isinstance(space, gym.spaces.Dict): if isinstance(space, gym.spaces.Dict):
assert isinstance(space.spaces, OrderedDict) assert isinstance(space.spaces, OrderedDict)
return {k: _setup_buf(v) for k, v in space.spaces.items()} return {k: _setup_buf(v, ctx) for k, v in space.spaces.items()}
if isinstance(space, gym.spaces.Tuple): if isinstance(space, gym.spaces.Tuple):
assert isinstance(space.spaces, tuple) assert isinstance(space.spaces, tuple)
return tuple([_setup_buf(t) for t in space.spaces]) return tuple([_setup_buf(t, ctx) for t in space.spaces])
return ShArray(space.dtype, space.shape) # type: ignore return ShArray(space.dtype, space.shape, ctx) # type: ignore
def _worker( def _worker(
@ -125,23 +142,31 @@ def _worker(
class SubprocEnvWorker(EnvWorker): class SubprocEnvWorker(EnvWorker):
"""Subprocess worker used in SubprocVectorEnv and ShmemVectorEnv.""" """Subprocess worker used in SubprocVectorEnv and ShmemVectorEnv."""
def __init__(self, env_fn: Callable[[], gym.Env], share_memory: bool = False) -> None: def __init__(
self,
env_fn: Callable[[], gym.Env],
share_memory: bool = False,
context: BaseContext | Literal["fork", "spawn"] | None = None,
) -> None:
self.parent_remote, self.child_remote = Pipe() self.parent_remote, self.child_remote = Pipe()
self.share_memory = share_memory self.share_memory = share_memory
self.buffer: dict | tuple | ShArray | None = None self.buffer: dict | tuple | ShArray | None = None
if not isinstance(context, BaseContext):
context = multiprocessing.get_context(context)
assert hasattr(context, "Process") # for mypy
if self.share_memory: if self.share_memory:
dummy = env_fn() dummy = env_fn()
obs_space = dummy.observation_space obs_space = dummy.observation_space
dummy.close() dummy.close()
del dummy del dummy
self.buffer = _setup_buf(obs_space) self.buffer = _setup_buf(obs_space, context)
args = ( args = (
self.parent_remote, self.parent_remote,
self.child_remote, self.child_remote,
CloudpickleWrapper(env_fn), CloudpickleWrapper(env_fn),
self.buffer, self.buffer,
) )
self.process = Process(target=_worker, args=args, daemon=True) self.process = context.Process(target=_worker, args=args, daemon=True)
self.process.start() self.process.start()
self.child_remote.close() self.child_remote.close()
super().__init__(env_fn) super().__init__(env_fn)

View File

@ -12,7 +12,6 @@ from tianshou.env import (
BaseVectorEnv, BaseVectorEnv,
DummyVectorEnv, DummyVectorEnv,
RayVectorEnv, RayVectorEnv,
ShmemVectorEnv,
SubprocVectorEnv, SubprocVectorEnv,
) )
from tianshou.highlevel.persistence import Persistence from tianshou.highlevel.persistence import Persistence
@ -69,17 +68,25 @@ class VectorEnvType(Enum):
"""Parallelization based on `subprocess`""" """Parallelization based on `subprocess`"""
SUBPROC_SHARED_MEM = "shmem" SUBPROC_SHARED_MEM = "shmem"
"""Parallelization based on `subprocess` with shared memory""" """Parallelization based on `subprocess` with shared memory"""
SUBPROC_SHARED_MEM_FORK_CONTEXT = "shmem_fork"
"""Parallelization based on `subprocess` with shared memory and fork context (relevant for macOS, which uses `spawn`
by default https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods)"""
RAY = "ray" RAY = "ray"
"""Parallelization based on the `ray` library""" """Parallelization based on the `ray` library"""
def create_venv(self, factories: Sequence[Callable[[], gym.Env]]) -> BaseVectorEnv: def create_venv(
self,
factories: Sequence[Callable[[], gym.Env]],
) -> BaseVectorEnv:
match self: match self:
case VectorEnvType.DUMMY: case VectorEnvType.DUMMY:
return DummyVectorEnv(factories) return DummyVectorEnv(factories)
case VectorEnvType.SUBPROC: case VectorEnvType.SUBPROC:
return SubprocVectorEnv(factories) return SubprocVectorEnv(factories)
case VectorEnvType.SUBPROC_SHARED_MEM: case VectorEnvType.SUBPROC_SHARED_MEM:
return ShmemVectorEnv(factories) return SubprocVectorEnv(factories, share_memory=True)
case VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT:
return SubprocVectorEnv(factories, share_memory=True, context="fork")
case VectorEnvType.RAY: case VectorEnvType.RAY:
return RayVectorEnv(factories) return RayVectorEnv(factories)
case _: case _:
@ -121,10 +128,14 @@ class Environments(ToStringMixin, ABC):
:param create_watch_env: whether to create an environment for watching the agent :param create_watch_env: whether to create an environment for watching the agent
:return: the instance :return: the instance
""" """
train_envs = venv_type.create_venv([lambda: factory_fn(EnvMode.TRAIN)] * num_training_envs) train_envs = venv_type.create_venv(
test_envs = venv_type.create_venv([lambda: factory_fn(EnvMode.TEST)] * num_test_envs) [lambda: factory_fn(EnvMode.TRAIN)] * num_training_envs,
)
test_envs = venv_type.create_venv(
[lambda: factory_fn(EnvMode.TEST)] * num_test_envs,
)
if create_watch_env: if create_watch_env:
watch_env = venv_type.create_venv([lambda: factory_fn(EnvMode.WATCH)]) watch_env = VectorEnvType.DUMMY.create_venv([lambda: factory_fn(EnvMode.WATCH)])
else: else:
watch_env = None watch_env = None
env = factory_fn(EnvMode.TRAIN) env = factory_fn(EnvMode.TRAIN)
@ -344,7 +355,9 @@ class EnvFactory(ToStringMixin, ABC):
"""Main interface for the creation of environments (in various forms).""" """Main interface for the creation of environments (in various forms)."""
def __init__(self, venv_type: VectorEnvType): def __init__(self, venv_type: VectorEnvType):
""":param venv_type: the type of vectorized environment to use""" """:param venv_type: the type of vectorized environment to use for train and test environments.
watch environments are always created as dummy environments.
"""
self.venv_type = venv_type self.venv_type = venv_type
@abstractmethod @abstractmethod
@ -355,10 +368,14 @@ class EnvFactory(ToStringMixin, ABC):
"""Create vectorized environments. """Create vectorized environments.
:param num_envs: the number of environments :param num_envs: the number of environments
:param mode: the mode for which to create :param mode: the mode for which to create. In `WATCH` mode the resulting venv will always be of type `DUMMY` with a single env.
:return: the vectorized environments :return: the vectorized environments
""" """
return self.venv_type.create_venv([lambda: self.create_env(mode)] * num_envs) if mode == EnvMode.WATCH:
return VectorEnvType.DUMMY.create_venv([lambda: self.create_env(mode)])
else:
return self.venv_type.create_venv([lambda: self.create_env(mode)] * num_envs)
def create_envs( def create_envs(
self, self,