Allow explicit setting of multiprocessing context for SubprocEnvWorker (#1072)
Running multiple training runs in parallel (with, for example, joblib) fails on macOS because the default multiprocessing start method there changed from `fork` to `spawn` in Python 3.8 (see [here](https://stackoverflow.com/questions/65098398/why-using-fork-works-but-using-spawn-fails-in-python3-8-multiprocessing) or [here](https://www.reddit.com/r/learnpython/comments/g5372v/multiprocessing_with_fork_on_macos/)). This PR adds the ability to explicitly set a multiprocessing context for the SubprocEnvWorker (similar to gymnasium's [AsyncVectorEnv](https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/vector/async_vector_env.py)). --------- Co-authored-by: Maximilian Huettenrauch <m.huettenrauch@appliedai.de> Co-authored-by: Michael Panchenko <35432522+MischaPanch@users.noreply.github.com>
This commit is contained in:
parent
1714c7f2c7
commit
e82379c47f
@ -253,4 +253,7 @@ Dominik
|
|||||||
Tsinghua
|
Tsinghua
|
||||||
Tianshou
|
Tianshou
|
||||||
appliedAI
|
appliedAI
|
||||||
|
macOS
|
||||||
|
joblib
|
||||||
|
master
|
||||||
Panchenko
|
Panchenko
|
||||||
|
@ -62,11 +62,17 @@ class MujocoEnvObsRmsPersistence(Persistence):
|
|||||||
|
|
||||||
|
|
||||||
class MujocoEnvFactory(EnvFactoryRegistered):
|
class MujocoEnvFactory(EnvFactoryRegistered):
|
||||||
def __init__(self, task: str, seed: int, obs_norm=True) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
task: str,
|
||||||
|
seed: int,
|
||||||
|
obs_norm: bool = True,
|
||||||
|
venv_type: VectorEnvType = VectorEnvType.SUBPROC_SHARED_MEM,
|
||||||
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
task=task,
|
task=task,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
venv_type=VectorEnvType.SUBPROC_SHARED_MEM,
|
venv_type=venv_type,
|
||||||
envpool_factory=EnvPoolFactory() if envpool_is_available else None,
|
envpool_factory=EnvPoolFactory() if envpool_is_available else None,
|
||||||
)
|
)
|
||||||
self.obs_norm = obs_norm
|
self.obs_norm = obs_norm
|
||||||
|
58
tianshou/env/venvs.py
vendored
58
tianshou/env/venvs.py
vendored
@ -1,5 +1,5 @@
|
|||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from typing import Any
|
from typing import Any, Literal
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -371,8 +371,13 @@ class DummyVectorEnv(BaseVectorEnv):
|
|||||||
Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
|
Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, env_fns: Sequence[Callable[[], ENV_TYPE]], **kwargs: Any) -> None:
|
def __init__(
|
||||||
super().__init__(env_fns, DummyEnvWorker, **kwargs)
|
self,
|
||||||
|
env_fns: Sequence[Callable[[], ENV_TYPE]],
|
||||||
|
wait_num: int | None = None,
|
||||||
|
timeout: float | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(env_fns, DummyEnvWorker, wait_num, timeout)
|
||||||
|
|
||||||
|
|
||||||
class SubprocVectorEnv(BaseVectorEnv):
|
class SubprocVectorEnv(BaseVectorEnv):
|
||||||
@ -381,13 +386,36 @@ class SubprocVectorEnv(BaseVectorEnv):
|
|||||||
.. seealso::
|
.. seealso::
|
||||||
|
|
||||||
Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
|
Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
|
||||||
|
|
||||||
|
Additional arguments are:
|
||||||
|
|
||||||
|
:param share_memory: whether to share memory between the main process and the worker process. Allows for
|
||||||
|
shared buffers to exchange observations
|
||||||
|
:param context: the context to use for multiprocessing. Usually it's fine to use the default context, but
|
||||||
|
`spawn` as well as `fork` can have non-obvious side effects, see for example
|
||||||
|
https://github.com/google-deepmind/mujoco/issues/742, or
|
||||||
|
https://github.com/Farama-Foundation/Gymnasium/issues/222.
|
||||||
|
Consider using 'fork' when using macOS and additional parallelization, for example via joblib.
|
||||||
|
Defaults to None, which will use the default system context.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, env_fns: Sequence[Callable[[], ENV_TYPE]], **kwargs: Any) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
env_fns: Sequence[Callable[[], ENV_TYPE]],
|
||||||
|
wait_num: int | None = None,
|
||||||
|
timeout: float | None = None,
|
||||||
|
share_memory: bool = False,
|
||||||
|
context: Literal["fork", "spawn"] | None = None,
|
||||||
|
) -> None:
|
||||||
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
|
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
|
||||||
return SubprocEnvWorker(fn, share_memory=False)
|
return SubprocEnvWorker(fn, share_memory=share_memory, context=context)
|
||||||
|
|
||||||
super().__init__(env_fns, worker_fn, **kwargs)
|
super().__init__(
|
||||||
|
env_fns,
|
||||||
|
worker_fn,
|
||||||
|
wait_num,
|
||||||
|
timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ShmemVectorEnv(BaseVectorEnv):
|
class ShmemVectorEnv(BaseVectorEnv):
|
||||||
@ -400,11 +428,16 @@ class ShmemVectorEnv(BaseVectorEnv):
|
|||||||
Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
|
Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, env_fns: Sequence[Callable[[], ENV_TYPE]], **kwargs: Any) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
env_fns: Sequence[Callable[[], ENV_TYPE]],
|
||||||
|
wait_num: int | None = None,
|
||||||
|
timeout: float | None = None,
|
||||||
|
) -> None:
|
||||||
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
|
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
|
||||||
return SubprocEnvWorker(fn, share_memory=True)
|
return SubprocEnvWorker(fn, share_memory=True)
|
||||||
|
|
||||||
super().__init__(env_fns, worker_fn, **kwargs)
|
super().__init__(env_fns, worker_fn, wait_num, timeout)
|
||||||
|
|
||||||
|
|
||||||
class RayVectorEnv(BaseVectorEnv):
|
class RayVectorEnv(BaseVectorEnv):
|
||||||
@ -417,7 +450,12 @@ class RayVectorEnv(BaseVectorEnv):
|
|||||||
Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
|
Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, env_fns: Sequence[Callable[[], ENV_TYPE]], **kwargs: Any) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
env_fns: Sequence[Callable[[], ENV_TYPE]],
|
||||||
|
wait_num: int | None = None,
|
||||||
|
timeout: float | None = None,
|
||||||
|
) -> None:
|
||||||
try:
|
try:
|
||||||
import ray
|
import ray
|
||||||
except ImportError as exception:
|
except ImportError as exception:
|
||||||
@ -426,4 +464,4 @@ class RayVectorEnv(BaseVectorEnv):
|
|||||||
) from exception
|
) from exception
|
||||||
if not ray.is_initialized():
|
if not ray.is_initialized():
|
||||||
ray.init()
|
ray.init()
|
||||||
super().__init__(env_fns, lambda env_fn: RayEnvWorker(env_fn), **kwargs)
|
super().__init__(env_fns, lambda env_fn: RayEnvWorker(env_fn), wait_num, timeout)
|
||||||
|
51
tianshou/env/worker/subproc.py
vendored
51
tianshou/env/worker/subproc.py
vendored
@ -1,10 +1,11 @@
|
|||||||
import ctypes
|
import ctypes
|
||||||
|
import multiprocessing
|
||||||
import time
|
import time
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from multiprocessing import Array, Pipe, connection
|
from multiprocessing import Pipe, connection
|
||||||
from multiprocessing.context import Process
|
from multiprocessing.context import BaseContext
|
||||||
from typing import Any
|
from typing import Any, Literal
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -31,10 +32,26 @@ _NP_TO_CT = {
|
|||||||
|
|
||||||
|
|
||||||
class ShArray:
|
class ShArray:
|
||||||
"""Wrapper of multiprocessing Array."""
|
"""Wrapper of multiprocessing Array.
|
||||||
|
|
||||||
def __init__(self, dtype: np.generic, shape: tuple[int]) -> None:
|
Example usage:
|
||||||
self.arr = Array(_NP_TO_CT[dtype.type], int(np.prod(shape))) # type: ignore
|
|
||||||
|
::
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import multiprocessing as mp
|
||||||
|
from tianshou.env.worker.subproc import ShArray
|
||||||
|
ctx = mp.get_context('fork') # set an explicit context
|
||||||
|
arr = ShArray(np.dtype(np.float32), (2, 3), ctx)
|
||||||
|
arr.save(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32))
|
||||||
|
print(arr.get())
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dtype: np.generic, shape: tuple[int], ctx: BaseContext | None) -> None:
|
||||||
|
if ctx is None:
|
||||||
|
ctx = multiprocessing.get_context()
|
||||||
|
self.arr = ctx.Array(_NP_TO_CT[dtype.type], int(np.prod(shape))) # type: ignore
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.shape = shape
|
self.shape = shape
|
||||||
|
|
||||||
@ -49,14 +66,14 @@ class ShArray:
|
|||||||
return np.frombuffer(obj, dtype=self.dtype).reshape(self.shape) # type: ignore
|
return np.frombuffer(obj, dtype=self.dtype).reshape(self.shape) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def _setup_buf(space: gym.Space) -> dict | tuple | ShArray:
|
def _setup_buf(space: gym.Space, ctx: BaseContext) -> dict | tuple | ShArray:
|
||||||
if isinstance(space, gym.spaces.Dict):
|
if isinstance(space, gym.spaces.Dict):
|
||||||
assert isinstance(space.spaces, OrderedDict)
|
assert isinstance(space.spaces, OrderedDict)
|
||||||
return {k: _setup_buf(v) for k, v in space.spaces.items()}
|
return {k: _setup_buf(v, ctx) for k, v in space.spaces.items()}
|
||||||
if isinstance(space, gym.spaces.Tuple):
|
if isinstance(space, gym.spaces.Tuple):
|
||||||
assert isinstance(space.spaces, tuple)
|
assert isinstance(space.spaces, tuple)
|
||||||
return tuple([_setup_buf(t) for t in space.spaces])
|
return tuple([_setup_buf(t, ctx) for t in space.spaces])
|
||||||
return ShArray(space.dtype, space.shape) # type: ignore
|
return ShArray(space.dtype, space.shape, ctx) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def _worker(
|
def _worker(
|
||||||
@ -125,23 +142,31 @@ def _worker(
|
|||||||
class SubprocEnvWorker(EnvWorker):
|
class SubprocEnvWorker(EnvWorker):
|
||||||
"""Subprocess worker used in SubprocVectorEnv and ShmemVectorEnv."""
|
"""Subprocess worker used in SubprocVectorEnv and ShmemVectorEnv."""
|
||||||
|
|
||||||
def __init__(self, env_fn: Callable[[], gym.Env], share_memory: bool = False) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
env_fn: Callable[[], gym.Env],
|
||||||
|
share_memory: bool = False,
|
||||||
|
context: BaseContext | Literal["fork", "spawn"] | None = None,
|
||||||
|
) -> None:
|
||||||
self.parent_remote, self.child_remote = Pipe()
|
self.parent_remote, self.child_remote = Pipe()
|
||||||
self.share_memory = share_memory
|
self.share_memory = share_memory
|
||||||
self.buffer: dict | tuple | ShArray | None = None
|
self.buffer: dict | tuple | ShArray | None = None
|
||||||
|
if not isinstance(context, BaseContext):
|
||||||
|
context = multiprocessing.get_context(context)
|
||||||
|
assert hasattr(context, "Process") # for mypy
|
||||||
if self.share_memory:
|
if self.share_memory:
|
||||||
dummy = env_fn()
|
dummy = env_fn()
|
||||||
obs_space = dummy.observation_space
|
obs_space = dummy.observation_space
|
||||||
dummy.close()
|
dummy.close()
|
||||||
del dummy
|
del dummy
|
||||||
self.buffer = _setup_buf(obs_space)
|
self.buffer = _setup_buf(obs_space, context)
|
||||||
args = (
|
args = (
|
||||||
self.parent_remote,
|
self.parent_remote,
|
||||||
self.child_remote,
|
self.child_remote,
|
||||||
CloudpickleWrapper(env_fn),
|
CloudpickleWrapper(env_fn),
|
||||||
self.buffer,
|
self.buffer,
|
||||||
)
|
)
|
||||||
self.process = Process(target=_worker, args=args, daemon=True)
|
self.process = context.Process(target=_worker, args=args, daemon=True)
|
||||||
self.process.start()
|
self.process.start()
|
||||||
self.child_remote.close()
|
self.child_remote.close()
|
||||||
super().__init__(env_fn)
|
super().__init__(env_fn)
|
||||||
|
@ -12,7 +12,6 @@ from tianshou.env import (
|
|||||||
BaseVectorEnv,
|
BaseVectorEnv,
|
||||||
DummyVectorEnv,
|
DummyVectorEnv,
|
||||||
RayVectorEnv,
|
RayVectorEnv,
|
||||||
ShmemVectorEnv,
|
|
||||||
SubprocVectorEnv,
|
SubprocVectorEnv,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.persistence import Persistence
|
from tianshou.highlevel.persistence import Persistence
|
||||||
@ -69,17 +68,25 @@ class VectorEnvType(Enum):
|
|||||||
"""Parallelization based on `subprocess`"""
|
"""Parallelization based on `subprocess`"""
|
||||||
SUBPROC_SHARED_MEM = "shmem"
|
SUBPROC_SHARED_MEM = "shmem"
|
||||||
"""Parallelization based on `subprocess` with shared memory"""
|
"""Parallelization based on `subprocess` with shared memory"""
|
||||||
|
SUBPROC_SHARED_MEM_FORK_CONTEXT = "shmem_fork"
|
||||||
|
"""Parallelization based on `subprocess` with shared memory and fork context (relevant for macOS, which uses `spawn`
|
||||||
|
by default https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods)"""
|
||||||
RAY = "ray"
|
RAY = "ray"
|
||||||
"""Parallelization based on the `ray` library"""
|
"""Parallelization based on the `ray` library"""
|
||||||
|
|
||||||
def create_venv(self, factories: Sequence[Callable[[], gym.Env]]) -> BaseVectorEnv:
|
def create_venv(
|
||||||
|
self,
|
||||||
|
factories: Sequence[Callable[[], gym.Env]],
|
||||||
|
) -> BaseVectorEnv:
|
||||||
match self:
|
match self:
|
||||||
case VectorEnvType.DUMMY:
|
case VectorEnvType.DUMMY:
|
||||||
return DummyVectorEnv(factories)
|
return DummyVectorEnv(factories)
|
||||||
case VectorEnvType.SUBPROC:
|
case VectorEnvType.SUBPROC:
|
||||||
return SubprocVectorEnv(factories)
|
return SubprocVectorEnv(factories)
|
||||||
case VectorEnvType.SUBPROC_SHARED_MEM:
|
case VectorEnvType.SUBPROC_SHARED_MEM:
|
||||||
return ShmemVectorEnv(factories)
|
return SubprocVectorEnv(factories, share_memory=True)
|
||||||
|
case VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT:
|
||||||
|
return SubprocVectorEnv(factories, share_memory=True, context="fork")
|
||||||
case VectorEnvType.RAY:
|
case VectorEnvType.RAY:
|
||||||
return RayVectorEnv(factories)
|
return RayVectorEnv(factories)
|
||||||
case _:
|
case _:
|
||||||
@ -121,10 +128,14 @@ class Environments(ToStringMixin, ABC):
|
|||||||
:param create_watch_env: whether to create an environment for watching the agent
|
:param create_watch_env: whether to create an environment for watching the agent
|
||||||
:return: the instance
|
:return: the instance
|
||||||
"""
|
"""
|
||||||
train_envs = venv_type.create_venv([lambda: factory_fn(EnvMode.TRAIN)] * num_training_envs)
|
train_envs = venv_type.create_venv(
|
||||||
test_envs = venv_type.create_venv([lambda: factory_fn(EnvMode.TEST)] * num_test_envs)
|
[lambda: factory_fn(EnvMode.TRAIN)] * num_training_envs,
|
||||||
|
)
|
||||||
|
test_envs = venv_type.create_venv(
|
||||||
|
[lambda: factory_fn(EnvMode.TEST)] * num_test_envs,
|
||||||
|
)
|
||||||
if create_watch_env:
|
if create_watch_env:
|
||||||
watch_env = venv_type.create_venv([lambda: factory_fn(EnvMode.WATCH)])
|
watch_env = VectorEnvType.DUMMY.create_venv([lambda: factory_fn(EnvMode.WATCH)])
|
||||||
else:
|
else:
|
||||||
watch_env = None
|
watch_env = None
|
||||||
env = factory_fn(EnvMode.TRAIN)
|
env = factory_fn(EnvMode.TRAIN)
|
||||||
@ -344,7 +355,9 @@ class EnvFactory(ToStringMixin, ABC):
|
|||||||
"""Main interface for the creation of environments (in various forms)."""
|
"""Main interface for the creation of environments (in various forms)."""
|
||||||
|
|
||||||
def __init__(self, venv_type: VectorEnvType):
|
def __init__(self, venv_type: VectorEnvType):
|
||||||
""":param venv_type: the type of vectorized environment to use"""
|
""":param venv_type: the type of vectorized environment to use for train and test environments.
|
||||||
|
watch environments are always created as dummy environments.
|
||||||
|
"""
|
||||||
self.venv_type = venv_type
|
self.venv_type = venv_type
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -355,10 +368,14 @@ class EnvFactory(ToStringMixin, ABC):
|
|||||||
"""Create vectorized environments.
|
"""Create vectorized environments.
|
||||||
|
|
||||||
:param num_envs: the number of environments
|
:param num_envs: the number of environments
|
||||||
:param mode: the mode for which to create
|
:param mode: the mode for which to create. In `WATCH` mode the resulting venv will always be of type `DUMMY` with a single env.
|
||||||
|
|
||||||
:return: the vectorized environments
|
:return: the vectorized environments
|
||||||
"""
|
"""
|
||||||
return self.venv_type.create_venv([lambda: self.create_env(mode)] * num_envs)
|
if mode == EnvMode.WATCH:
|
||||||
|
return VectorEnvType.DUMMY.create_venv([lambda: self.create_env(mode)])
|
||||||
|
else:
|
||||||
|
return self.venv_type.create_venv([lambda: self.create_env(mode)] * num_envs)
|
||||||
|
|
||||||
def create_envs(
|
def create_envs(
|
||||||
self,
|
self,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user