Added and used new VenvType: SUBPROC_SHARED_MEM_AUTO

This commit is contained in:
Michael Panchenko 2024-05-06 21:22:39 +02:00
parent d58ae163f2
commit 1cd22f1d32
4 changed files with 16 additions and 8 deletions

View File

@ -397,6 +397,7 @@ class AtariEnvFactory(EnvFactoryRegistered):
frame_stack: int, frame_stack: int,
scale: bool = False, scale: bool = False,
use_envpool_if_available: bool = True, use_envpool_if_available: bool = True,
venv_type: VectorEnvType = VectorEnvType.SUBPROC_SHARED_MEM_AUTO,
) -> None: ) -> None:
assert "NoFrameskip" in task assert "NoFrameskip" in task
self.frame_stack = frame_stack self.frame_stack = frame_stack
@ -412,7 +413,7 @@ class AtariEnvFactory(EnvFactoryRegistered):
task=task, task=task,
train_seed=train_seed, train_seed=train_seed,
test_seed=test_seed, test_seed=test_seed,
venv_type=VectorEnvType.SUBPROC_SHARED_MEM, venv_type=venv_type,
envpool_factory=envpool_factory, envpool_factory=envpool_factory,
) )

View File

@ -76,7 +76,7 @@ class MujocoEnvFactory(EnvFactoryRegistered):
train_seed: int, train_seed: int,
test_seed: int, test_seed: int,
obs_norm: bool = True, obs_norm: bool = True,
venv_type: VectorEnvType = VectorEnvType.SUBPROC_SHARED_MEM, venv_type: VectorEnvType = VectorEnvType.SUBPROC_SHARED_MEM_AUTO,
) -> None: ) -> None:
super().__init__( super().__init__(
task=task, task=task,

View File

@ -14,6 +14,8 @@ These plots are saved in the log directory and displayed in the console.
import os import os
import sys import sys
from collections.abc import Sequence
from typing import Literal
import torch import torch
@ -21,7 +23,6 @@ from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.evaluation.launcher import RegisteredExpLauncher from tianshou.evaluation.launcher import RegisteredExpLauncher
from tianshou.evaluation.rliable_evaluation_hl import RLiableExperimentResult from tianshou.evaluation.rliable_evaluation_hl import RLiableExperimentResult
from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import VectorEnvType
from tianshou.highlevel.experiment import ( from tianshou.highlevel.experiment import (
ExperimentConfig, ExperimentConfig,
PPOExperimentBuilder, PPOExperimentBuilder,
@ -70,9 +71,6 @@ def main(
train_seed=sampling_config.train_seed, train_seed=sampling_config.train_seed,
test_seed=sampling_config.test_seed, test_seed=sampling_config.test_seed,
obs_norm=True, obs_norm=True,
venv_type=VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT
if sys.platform == "darwin"
else VectorEnvType.SUBPROC_SHARED_MEM,
) )
hidden_sizes = (64, 64) hidden_sizes = (64, 64)

View File

@ -1,4 +1,5 @@
import logging import logging
import platform
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from enum import Enum from enum import Enum
@ -66,13 +67,15 @@ class VectorEnvType(Enum):
"""Vectorized environment without parallelization; environments are processed sequentially""" """Vectorized environment without parallelization; environments are processed sequentially"""
SUBPROC = "subproc" SUBPROC = "subproc"
"""Parallelization based on `subprocess`""" """Parallelization based on `subprocess`"""
SUBPROC_SHARED_MEM = "shmem" SUBPROC_SHARED_MEM_DEFAULT_CONTEXT = "shmem"
"""Parallelization based on `subprocess` with shared memory""" """Parallelization based on `subprocess` with shared memory"""
SUBPROC_SHARED_MEM_FORK_CONTEXT = "shmem_fork" SUBPROC_SHARED_MEM_FORK_CONTEXT = "shmem_fork"
"""Parallelization based on `subprocess` with shared memory and fork context (relevant for macOS, which uses `spawn` """Parallelization based on `subprocess` with shared memory and fork context (relevant for macOS, which uses `spawn`
by default https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods)""" by default https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods)"""
RAY = "ray" RAY = "ray"
"""Parallelization based on the `ray` library""" """Parallelization based on the `ray` library"""
SUBPROC_SHARED_MEM_AUTO = "subproc_shared_mem_auto"
"""Parallelization based on `subprocess` with shared memory, using default context on windows and fork context otherwise"""
def create_venv( def create_venv(
self, self,
@ -83,10 +86,16 @@ class VectorEnvType(Enum):
return DummyVectorEnv(factories) return DummyVectorEnv(factories)
case VectorEnvType.SUBPROC: case VectorEnvType.SUBPROC:
return SubprocVectorEnv(factories) return SubprocVectorEnv(factories)
case VectorEnvType.SUBPROC_SHARED_MEM: case VectorEnvType.SUBPROC_SHARED_MEM_DEFAULT_CONTEXT:
return SubprocVectorEnv(factories, share_memory=True) return SubprocVectorEnv(factories, share_memory=True)
case VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT: case VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT:
return SubprocVectorEnv(factories, share_memory=True, context="fork") return SubprocVectorEnv(factories, share_memory=True, context="fork")
case VectorEnvType.SUBPROC_SHARED_MEM_AUTO:
if platform.system().lower() == "windows":
selected_venv_type = VectorEnvType.SUBPROC_SHARED_MEM_DEFAULT_CONTEXT
else:
selected_venv_type = VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT
return selected_venv_type.create_venv(factories)
case VectorEnvType.RAY: case VectorEnvType.RAY:
return RayVectorEnv(factories) return RayVectorEnv(factories)
case _: case _: