Added and used new VenvType: SUBPROC_SHARED_MEM_AUTO

This commit is contained in:
Michael Panchenko 2024-05-06 21:22:39 +02:00
parent d58ae163f2
commit 1cd22f1d32
4 changed files with 16 additions and 8 deletions

View File

@ -397,6 +397,7 @@ class AtariEnvFactory(EnvFactoryRegistered):
frame_stack: int,
scale: bool = False,
use_envpool_if_available: bool = True,
venv_type: VectorEnvType = VectorEnvType.SUBPROC_SHARED_MEM_AUTO,
) -> None:
assert "NoFrameskip" in task
self.frame_stack = frame_stack
@ -412,7 +413,7 @@ class AtariEnvFactory(EnvFactoryRegistered):
task=task,
train_seed=train_seed,
test_seed=test_seed,
venv_type=VectorEnvType.SUBPROC_SHARED_MEM,
venv_type=venv_type,
envpool_factory=envpool_factory,
)

View File

@ -76,7 +76,7 @@ class MujocoEnvFactory(EnvFactoryRegistered):
train_seed: int,
test_seed: int,
obs_norm: bool = True,
venv_type: VectorEnvType = VectorEnvType.SUBPROC_SHARED_MEM,
venv_type: VectorEnvType = VectorEnvType.SUBPROC_SHARED_MEM_AUTO,
) -> None:
super().__init__(
task=task,

View File

@ -14,6 +14,8 @@ These plots are saved in the log directory and displayed in the console.
import os
import sys
from collections.abc import Sequence
from typing import Literal
import torch
@ -21,7 +23,6 @@ from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.evaluation.launcher import RegisteredExpLauncher
from tianshou.evaluation.rliable_evaluation_hl import RLiableExperimentResult
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import VectorEnvType
from tianshou.highlevel.experiment import (
ExperimentConfig,
PPOExperimentBuilder,
@ -70,9 +71,6 @@ def main(
train_seed=sampling_config.train_seed,
test_seed=sampling_config.test_seed,
obs_norm=True,
venv_type=VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT
if sys.platform == "darwin"
else VectorEnvType.SUBPROC_SHARED_MEM,
)
hidden_sizes = (64, 64)

View File

@ -1,4 +1,5 @@
import logging
import platform
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
from enum import Enum
@ -66,13 +67,15 @@ class VectorEnvType(Enum):
"""Vectorized environment without parallelization; environments are processed sequentially"""
SUBPROC = "subproc"
"""Parallelization based on `subprocess`"""
SUBPROC_SHARED_MEM = "shmem"
SUBPROC_SHARED_MEM_DEFAULT_CONTEXT = "shmem"
"""Parallelization based on `subprocess` with shared memory, using the platform's default multiprocessing context"""
SUBPROC_SHARED_MEM_FORK_CONTEXT = "shmem_fork"
"""Parallelization based on `subprocess` with shared memory and fork context (relevant for macOS, which uses `spawn`
by default; see https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods)"""
RAY = "ray"
"""Parallelization based on the `ray` library"""
SUBPROC_SHARED_MEM_AUTO = "subproc_shared_mem_auto"
"""Parallelization based on `subprocess` with shared memory, using the default context on Windows and the fork context otherwise"""
def create_venv(
self,
@ -83,10 +86,16 @@ class VectorEnvType(Enum):
return DummyVectorEnv(factories)
case VectorEnvType.SUBPROC:
return SubprocVectorEnv(factories)
case VectorEnvType.SUBPROC_SHARED_MEM:
case VectorEnvType.SUBPROC_SHARED_MEM_DEFAULT_CONTEXT:
return SubprocVectorEnv(factories, share_memory=True)
case VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT:
return SubprocVectorEnv(factories, share_memory=True, context="fork")
case VectorEnvType.SUBPROC_SHARED_MEM_AUTO:
if platform.system().lower() == "windows":
selected_venv_type = VectorEnvType.SUBPROC_SHARED_MEM_DEFAULT_CONTEXT
else:
selected_venv_type = VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT
return selected_venv_type.create_venv(factories)
case VectorEnvType.RAY:
return RayVectorEnv(factories)
case _: