From 1cd22f1d32ef88c95c4a0f5808d3cdbe7b9076a3 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Mon, 6 May 2024 21:22:39 +0200 Subject: [PATCH] Added and used new VenvType: SUBPROC_SHARED_MEM_AUTO --- examples/atari/atari_wrapper.py | 3 ++- examples/mujoco/mujoco_env.py | 2 +- examples/mujoco/mujoco_ppo_hl_multi.py | 6 ++---- tianshou/highlevel/env.py | 13 +++++++++++-- 4 files changed, 16 insertions(+), 8 deletions(-) diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py index a5135f5..d7234d8 100644 --- a/examples/atari/atari_wrapper.py +++ b/examples/atari/atari_wrapper.py @@ -397,6 +397,7 @@ class AtariEnvFactory(EnvFactoryRegistered): frame_stack: int, scale: bool = False, use_envpool_if_available: bool = True, + venv_type: VectorEnvType = VectorEnvType.SUBPROC_SHARED_MEM_AUTO, ) -> None: assert "NoFrameskip" in task self.frame_stack = frame_stack @@ -412,7 +413,7 @@ class AtariEnvFactory(EnvFactoryRegistered): task=task, train_seed=train_seed, test_seed=test_seed, - venv_type=VectorEnvType.SUBPROC_SHARED_MEM, + venv_type=venv_type, envpool_factory=envpool_factory, ) diff --git a/examples/mujoco/mujoco_env.py b/examples/mujoco/mujoco_env.py index d0a3b8d..90f2799 100644 --- a/examples/mujoco/mujoco_env.py +++ b/examples/mujoco/mujoco_env.py @@ -76,7 +76,7 @@ class MujocoEnvFactory(EnvFactoryRegistered): train_seed: int, test_seed: int, obs_norm: bool = True, - venv_type: VectorEnvType = VectorEnvType.SUBPROC_SHARED_MEM, + venv_type: VectorEnvType = VectorEnvType.SUBPROC_SHARED_MEM_AUTO, ) -> None: super().__init__( task=task, diff --git a/examples/mujoco/mujoco_ppo_hl_multi.py b/examples/mujoco/mujoco_ppo_hl_multi.py index 319375f..6f67de3 100644 --- a/examples/mujoco/mujoco_ppo_hl_multi.py +++ b/examples/mujoco/mujoco_ppo_hl_multi.py @@ -14,6 +14,8 @@ These plots are saved in the log directory and displayed in the console. import os import sys +from collections.abc import Sequence +from typing import Literal import torch @@ -21,7 +23,6 @@ from examples.mujoco.mujoco_env import MujocoEnvFactory from tianshou.evaluation.launcher import RegisteredExpLauncher from tianshou.evaluation.rliable_evaluation_hl import RLiableExperimentResult from tianshou.highlevel.config import SamplingConfig -from tianshou.highlevel.env import VectorEnvType from tianshou.highlevel.experiment import ( ExperimentConfig, PPOExperimentBuilder, @@ -70,9 +71,6 @@ def main( train_seed=sampling_config.train_seed, test_seed=sampling_config.test_seed, obs_norm=True, - venv_type=VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT - if sys.platform == "darwin" - else VectorEnvType.SUBPROC_SHARED_MEM, ) hidden_sizes = (64, 64) diff --git a/tianshou/highlevel/env.py b/tianshou/highlevel/env.py index 04b18ee..c6c692f 100644 --- a/tianshou/highlevel/env.py +++ b/tianshou/highlevel/env.py @@ -1,4 +1,5 @@ import logging +import platform from abc import ABC, abstractmethod from collections.abc import Callable, Sequence from enum import Enum @@ -66,13 +67,15 @@ class VectorEnvType(Enum): """Vectorized environment without parallelization; environments are processed sequentially""" SUBPROC = "subproc" """Parallelization based on `subprocess`""" - SUBPROC_SHARED_MEM = "shmem" + SUBPROC_SHARED_MEM_DEFAULT_CONTEXT = "shmem" """Parallelization based on `subprocess` with shared memory""" SUBPROC_SHARED_MEM_FORK_CONTEXT = "shmem_fork" """Parallelization based on `subprocess` with shared memory and fork context (relevant for macOS, which uses `spawn` by default https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods)""" RAY = "ray" """Parallelization based on the `ray` library""" + SUBPROC_SHARED_MEM_AUTO = "subproc_shared_mem_auto" + """Parallelization based on `subprocess` with shared memory, using default context on windows and fork context otherwise""" def create_venv( self, @@ -83,10 +86,16 @@ class VectorEnvType(Enum): return DummyVectorEnv(factories) case VectorEnvType.SUBPROC: return SubprocVectorEnv(factories) - case VectorEnvType.SUBPROC_SHARED_MEM: + case VectorEnvType.SUBPROC_SHARED_MEM_DEFAULT_CONTEXT: return SubprocVectorEnv(factories, share_memory=True) case VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT: return SubprocVectorEnv(factories, share_memory=True, context="fork") + case VectorEnvType.SUBPROC_SHARED_MEM_AUTO: + if platform.system().lower() == "windows": + selected_venv_type = VectorEnvType.SUBPROC_SHARED_MEM_DEFAULT_CONTEXT + else: + selected_venv_type = VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT + return selected_venv_type.create_venv(factories) case VectorEnvType.RAY: return RayVectorEnv(factories) case _: