diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index d3cd95a..63ee791 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -253,4 +253,7 @@ Dominik
 Tsinghua
 Tianshou
 appliedAI
+macOS
+joblib
+master
 Panchenko
diff --git a/examples/mujoco/mujoco_env.py b/examples/mujoco/mujoco_env.py
index b04f243..dacf915 100644
--- a/examples/mujoco/mujoco_env.py
+++ b/examples/mujoco/mujoco_env.py
@@ -62,11 +62,17 @@ class MujocoEnvObsRmsPersistence(Persistence):
 
 
 class MujocoEnvFactory(EnvFactoryRegistered):
-    def __init__(self, task: str, seed: int, obs_norm=True) -> None:
+    def __init__(
+        self,
+        task: str,
+        seed: int,
+        obs_norm: bool = True,
+        venv_type: VectorEnvType = VectorEnvType.SUBPROC_SHARED_MEM,
+    ) -> None:
         super().__init__(
             task=task,
             seed=seed,
-            venv_type=VectorEnvType.SUBPROC_SHARED_MEM,
+            venv_type=venv_type,
             envpool_factory=EnvPoolFactory() if envpool_is_available else None,
         )
         self.obs_norm = obs_norm
diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py
index 08f7650..dfcd12e 100644
--- a/tianshou/env/venvs.py
+++ b/tianshou/env/venvs.py
@@ -1,5 +1,5 @@
 from collections.abc import Callable, Sequence
-from typing import Any
+from typing import Any, Literal
 
 import gymnasium as gym
 import numpy as np
@@ -371,8 +371,13 @@ class DummyVectorEnv(BaseVectorEnv):
         Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
     """
 
-    def __init__(self, env_fns: Sequence[Callable[[], ENV_TYPE]], **kwargs: Any) -> None:
-        super().__init__(env_fns, DummyEnvWorker, **kwargs)
+    def __init__(
+        self,
+        env_fns: Sequence[Callable[[], ENV_TYPE]],
+        wait_num: int | None = None,
+        timeout: float | None = None,
+    ) -> None:
+        super().__init__(env_fns, DummyEnvWorker, wait_num, timeout)
 
 
 class SubprocVectorEnv(BaseVectorEnv):
@@ -381,13 +386,36 @@ class SubprocVectorEnv(BaseVectorEnv):
     .. seealso::
 
         Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
+
+    Additional arguments are:
+
+    :param share_memory: whether to share memory between the main process and the worker process. Allows for
+        shared buffers to exchange observations
+    :param context: the context to use for multiprocessing. Usually it's fine to use the default context, but
+        `spawn` as well as `fork` can have non-obvious side effects, see for example
+        https://github.com/google-deepmind/mujoco/issues/742, or
+        https://github.com/Farama-Foundation/Gymnasium/issues/222.
+        Consider using 'fork' when using macOS and additional parallelization, for example via joblib.
+        Defaults to None, which will use the default system context.
     """
 
-    def __init__(self, env_fns: Sequence[Callable[[], ENV_TYPE]], **kwargs: Any) -> None:
+    def __init__(
+        self,
+        env_fns: Sequence[Callable[[], ENV_TYPE]],
+        wait_num: int | None = None,
+        timeout: float | None = None,
+        share_memory: bool = False,
+        context: Literal["fork", "spawn"] | None = None,
+    ) -> None:
         def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
-            return SubprocEnvWorker(fn, share_memory=False)
+            return SubprocEnvWorker(fn, share_memory=share_memory, context=context)
 
-        super().__init__(env_fns, worker_fn, **kwargs)
+        super().__init__(
+            env_fns,
+            worker_fn,
+            wait_num,
+            timeout,
+        )
 
 
 class ShmemVectorEnv(BaseVectorEnv):
@@ -400,11 +428,16 @@ class ShmemVectorEnv(BaseVectorEnv):
         Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
     """
 
-    def __init__(self, env_fns: Sequence[Callable[[], ENV_TYPE]], **kwargs: Any) -> None:
+    def __init__(
+        self,
+        env_fns: Sequence[Callable[[], ENV_TYPE]],
+        wait_num: int | None = None,
+        timeout: float | None = None,
+    ) -> None:
         def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
             return SubprocEnvWorker(fn, share_memory=True)
 
-        super().__init__(env_fns, worker_fn, **kwargs)
+        super().__init__(env_fns, worker_fn, wait_num, timeout)
 
 
 class RayVectorEnv(BaseVectorEnv):
@@ -417,7 +450,12 @@
         Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
     """
 
-    def __init__(self, env_fns: Sequence[Callable[[], ENV_TYPE]], **kwargs: Any) -> None:
+    def __init__(
+        self,
+        env_fns: Sequence[Callable[[], ENV_TYPE]],
+        wait_num: int | None = None,
+        timeout: float | None = None,
+    ) -> None:
         try:
             import ray
         except ImportError as exception:
@@ -426,4 +464,4 @@
             ) from exception
         if not ray.is_initialized():
             ray.init()
-        super().__init__(env_fns, lambda env_fn: RayEnvWorker(env_fn), **kwargs)
+        super().__init__(env_fns, lambda env_fn: RayEnvWorker(env_fn), wait_num, timeout)
diff --git a/tianshou/env/worker/subproc.py b/tianshou/env/worker/subproc.py
index af5ec4e..ab84ac0 100644
--- a/tianshou/env/worker/subproc.py
+++ b/tianshou/env/worker/subproc.py
@@ -1,10 +1,11 @@
 import ctypes
+import multiprocessing
 import time
 from collections import OrderedDict
 from collections.abc import Callable
-from multiprocessing import Array, Pipe, connection
-from multiprocessing.context import Process
-from typing import Any
+from multiprocessing import Pipe, connection
+from multiprocessing.context import BaseContext
+from typing import Any, Literal
 
 import gymnasium as gym
 import numpy as np
@@ -31,10 +32,26 @@ _NP_TO_CT = {
 
 
 class ShArray:
-    """Wrapper of multiprocessing Array."""
+    """Wrapper of multiprocessing Array.
 
-    def __init__(self, dtype: np.generic, shape: tuple[int]) -> None:
-        self.arr = Array(_NP_TO_CT[dtype.type], int(np.prod(shape)))  # type: ignore
+    Example usage:
+
+    ::
+
+        import numpy as np
+        import multiprocessing as mp
+        from tianshou.env.worker.subproc import ShArray
+        ctx = mp.get_context('fork')  # set an explicit context
+        arr = ShArray(np.dtype(np.float32), (2, 3), ctx)
+        arr.save(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32))
+        print(arr.get())
+
+    """
+
+    def __init__(self, dtype: np.generic, shape: tuple[int], ctx: BaseContext | None) -> None:
+        if ctx is None:
+            ctx = multiprocessing.get_context()
+        self.arr = ctx.Array(_NP_TO_CT[dtype.type], int(np.prod(shape)))  # type: ignore
         self.dtype = dtype
         self.shape = shape
 
@@ -49,14 +66,14 @@ class ShArray:
         return np.frombuffer(obj, dtype=self.dtype).reshape(self.shape)  # type: ignore
 
 
-def _setup_buf(space: gym.Space) -> dict | tuple | ShArray:
+def _setup_buf(space: gym.Space, ctx: BaseContext) -> dict | tuple | ShArray:
     if isinstance(space, gym.spaces.Dict):
         assert isinstance(space.spaces, OrderedDict)
-        return {k: _setup_buf(v) for k, v in space.spaces.items()}
+        return {k: _setup_buf(v, ctx) for k, v in space.spaces.items()}
     if isinstance(space, gym.spaces.Tuple):
         assert isinstance(space.spaces, tuple)
-        return tuple([_setup_buf(t) for t in space.spaces])
-    return ShArray(space.dtype, space.shape)  # type: ignore
+        return tuple([_setup_buf(t, ctx) for t in space.spaces])
+    return ShArray(space.dtype, space.shape, ctx)  # type: ignore
 
 
 def _worker(
@@ -125,23 +142,31 @@
 class SubprocEnvWorker(EnvWorker):
     """Subprocess worker used in SubprocVectorEnv and ShmemVectorEnv."""
 
-    def __init__(self, env_fn: Callable[[], gym.Env], share_memory: bool = False) -> None:
+    def __init__(
+        self,
+        env_fn: Callable[[], gym.Env],
+        share_memory: bool = False,
+        context: BaseContext | Literal["fork", "spawn"] | None = None,
+    ) -> None:
         self.parent_remote, self.child_remote = Pipe()
         self.share_memory = share_memory
         self.buffer: dict | tuple | ShArray | None = None
+        if not isinstance(context, BaseContext):
+            context = multiprocessing.get_context(context)
+        assert hasattr(context, "Process")  # for mypy
         if self.share_memory:
             dummy = env_fn()
             obs_space = dummy.observation_space
             dummy.close()
             del dummy
-            self.buffer = _setup_buf(obs_space)
+            self.buffer = _setup_buf(obs_space, context)
         args = (
             self.parent_remote,
             self.child_remote,
             CloudpickleWrapper(env_fn),
             self.buffer,
         )
-        self.process = Process(target=_worker, args=args, daemon=True)
+        self.process = context.Process(target=_worker, args=args, daemon=True)
         self.process.start()
         self.child_remote.close()
         super().__init__(env_fn)
diff --git a/tianshou/highlevel/env.py b/tianshou/highlevel/env.py
index 71de0f8..bd53975 100644
--- a/tianshou/highlevel/env.py
+++ b/tianshou/highlevel/env.py
@@ -12,7 +12,6 @@ from tianshou.env import (
     BaseVectorEnv,
     DummyVectorEnv,
     RayVectorEnv,
-    ShmemVectorEnv,
     SubprocVectorEnv,
 )
 from tianshou.highlevel.persistence import Persistence
@@ -69,17 +68,25 @@ class VectorEnvType(Enum):
     """Parallelization based on `subprocess`"""
     SUBPROC_SHARED_MEM = "shmem"
     """Parallelization based on `subprocess` with shared memory"""
+    SUBPROC_SHARED_MEM_FORK_CONTEXT = "shmem_fork"
+    """Parallelization based on `subprocess` with shared memory and fork context (relevant for macOS, which uses `spawn`
+    by default https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods)"""
     RAY = "ray"
     """Parallelization based on the `ray` library"""
 
-    def create_venv(self, factories: Sequence[Callable[[], gym.Env]]) -> BaseVectorEnv:
+    def create_venv(
+        self,
+        factories: Sequence[Callable[[], gym.Env]],
+    ) -> BaseVectorEnv:
         match self:
             case VectorEnvType.DUMMY:
                 return DummyVectorEnv(factories)
             case VectorEnvType.SUBPROC:
                 return SubprocVectorEnv(factories)
             case VectorEnvType.SUBPROC_SHARED_MEM:
-                return ShmemVectorEnv(factories)
+                return SubprocVectorEnv(factories, share_memory=True)
+            case VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT:
+                return SubprocVectorEnv(factories, share_memory=True, context="fork")
             case VectorEnvType.RAY:
                 return RayVectorEnv(factories)
             case _:
@@ -121,10 +128,14 @@ class Environments(ToStringMixin, ABC):
         :param create_watch_env: whether to create an environment for watching the agent
         :return: the instance
         """
-        train_envs = venv_type.create_venv([lambda: factory_fn(EnvMode.TRAIN)] * num_training_envs)
-        test_envs = venv_type.create_venv([lambda: factory_fn(EnvMode.TEST)] * num_test_envs)
+        train_envs = venv_type.create_venv(
+            [lambda: factory_fn(EnvMode.TRAIN)] * num_training_envs,
+        )
+        test_envs = venv_type.create_venv(
+            [lambda: factory_fn(EnvMode.TEST)] * num_test_envs,
+        )
         if create_watch_env:
-            watch_env = venv_type.create_venv([lambda: factory_fn(EnvMode.WATCH)])
+            watch_env = VectorEnvType.DUMMY.create_venv([lambda: factory_fn(EnvMode.WATCH)])
         else:
             watch_env = None
         env = factory_fn(EnvMode.TRAIN)
@@ -344,7 +355,9 @@ class EnvFactory(ToStringMixin, ABC):
     """Main interface for the creation of environments (in various forms)."""
 
    def __init__(self, venv_type: VectorEnvType):
-        """:param venv_type: the type of vectorized environment to use"""
+        """:param venv_type: the type of vectorized environment to use for train and test environments.
+        watch environments are always created as dummy environments.
+        """
         self.venv_type = venv_type
 
     @abstractmethod
@@ -355,10 +368,14 @@
         """Create vectorized environments.
 
         :param num_envs: the number of environments
-        :param mode: the mode for which to create
+        :param mode: the mode for which to create. In `WATCH` mode the resulting venv will always be of type `DUMMY` with a single env.
+
+        :return: the vectorized environments
         """
-        return self.venv_type.create_venv([lambda: self.create_env(mode)] * num_envs)
+        if mode == EnvMode.WATCH:
+            return VectorEnvType.DUMMY.create_venv([lambda: self.create_env(mode)])
+        else:
+            return self.venv_type.create_venv([lambda: self.create_env(mode)] * num_envs)
 
     def create_envs(
         self,
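
Usage sketch for the new `SubprocVectorEnv` arguments. This is a minimal example, not part of the patch; the `CartPole-v1` task id and the worker count are illustrative assumptions::

    import gymnasium as gym

    from tianshou.env import SubprocVectorEnv

    # Shared-memory observation buffers combined with an explicit "fork" start
    # method, e.g. when mixing Tianshou with joblib-based parallelism on macOS.
    venv = SubprocVectorEnv(
        [lambda: gym.make("CartPole-v1") for _ in range(4)],
        share_memory=True,  # same worker setup that ShmemVectorEnv uses
        context="fork",     # None falls back to the system default context
    )
    venv.reset()
    venv.close()

Leaving `context=None` keeps the previous behaviour, since the worker then uses the default system start method.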
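
At the high level, the same capability is reachable through the new `venv_type` parameter of `MujocoEnvFactory` together with the `SUBPROC_SHARED_MEM_FORK_CONTEXT` enum value. A sketch under a few assumptions: it is run from `examples/mujoco` so that `mujoco_env` is importable, `Ant-v4` and the seed are placeholder choices, and `create_venv(num_envs, mode)` is the signature implied by the docstring above::

    from mujoco_env import MujocoEnvFactory  # examples/mujoco/mujoco_env.py

    from tianshou.highlevel.env import EnvMode, VectorEnvType

    # Request shared-memory subprocess workers started via the "fork" context
    # instead of the default SUBPROC_SHARED_MEM.
    env_factory = MujocoEnvFactory(
        task="Ant-v4",
        seed=42,
        obs_norm=True,
        venv_type=VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT,
    )

    train_venv = env_factory.create_venv(4, mode=EnvMode.TRAIN)
    # WATCH mode now always yields a single in-process (dummy) environment.
    watch_venv = env_factory.create_venv(1, mode=EnvMode.WATCH)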