Added and used new VenvType: SUBPROC_SHARED_MEM_AUTO
This commit is contained in:
parent
d58ae163f2
commit
1cd22f1d32
@ -397,6 +397,7 @@ class AtariEnvFactory(EnvFactoryRegistered):
|
||||
frame_stack: int,
|
||||
scale: bool = False,
|
||||
use_envpool_if_available: bool = True,
|
||||
venv_type: VectorEnvType = VectorEnvType.SUBPROC_SHARED_MEM_AUTO,
|
||||
) -> None:
|
||||
assert "NoFrameskip" in task
|
||||
self.frame_stack = frame_stack
|
||||
@ -412,7 +413,7 @@ class AtariEnvFactory(EnvFactoryRegistered):
|
||||
task=task,
|
||||
train_seed=train_seed,
|
||||
test_seed=test_seed,
|
||||
venv_type=VectorEnvType.SUBPROC_SHARED_MEM,
|
||||
venv_type=venv_type,
|
||||
envpool_factory=envpool_factory,
|
||||
)
|
||||
|
||||
|
@ -76,7 +76,7 @@ class MujocoEnvFactory(EnvFactoryRegistered):
|
||||
train_seed: int,
|
||||
test_seed: int,
|
||||
obs_norm: bool = True,
|
||||
venv_type: VectorEnvType = VectorEnvType.SUBPROC_SHARED_MEM,
|
||||
venv_type: VectorEnvType = VectorEnvType.SUBPROC_SHARED_MEM_AUTO,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
task=task,
|
||||
|
@ -14,6 +14,8 @@ These plots are saved in the log directory and displayed in the console.
|
||||
|
||||
import os
|
||||
import sys
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
|
||||
@ -21,7 +23,6 @@ from examples.mujoco.mujoco_env import MujocoEnvFactory
|
||||
from tianshou.evaluation.launcher import RegisteredExpLauncher
|
||||
from tianshou.evaluation.rliable_evaluation_hl import RLiableExperimentResult
|
||||
from tianshou.highlevel.config import SamplingConfig
|
||||
from tianshou.highlevel.env import VectorEnvType
|
||||
from tianshou.highlevel.experiment import (
|
||||
ExperimentConfig,
|
||||
PPOExperimentBuilder,
|
||||
@ -70,9 +71,6 @@ def main(
|
||||
train_seed=sampling_config.train_seed,
|
||||
test_seed=sampling_config.test_seed,
|
||||
obs_norm=True,
|
||||
venv_type=VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT
|
||||
if sys.platform == "darwin"
|
||||
else VectorEnvType.SUBPROC_SHARED_MEM,
|
||||
)
|
||||
|
||||
hidden_sizes = (64, 64)
|
||||
|
@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import platform
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Sequence
|
||||
from enum import Enum
|
||||
@ -66,13 +67,15 @@ class VectorEnvType(Enum):
|
||||
"""Vectorized environment without parallelization; environments are processed sequentially"""
|
||||
SUBPROC = "subproc"
|
||||
"""Parallelization based on `subprocess`"""
|
||||
SUBPROC_SHARED_MEM = "shmem"
|
||||
SUBPROC_SHARED_MEM_DEFAULT_CONTEXT = "shmem"
|
||||
"""Parallelization based on `subprocess` with shared memory"""
|
||||
SUBPROC_SHARED_MEM_FORK_CONTEXT = "shmem_fork"
|
||||
"""Parallelization based on `subprocess` with shared memory and fork context (relevant for macOS, which uses `spawn`
|
||||
by default https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods)"""
|
||||
RAY = "ray"
|
||||
"""Parallelization based on the `ray` library"""
|
||||
SUBPROC_SHARED_MEM_AUTO = "subproc_shared_mem_auto"
|
||||
"""Parallelization based on `subprocess` with shared memory, using default context on windows and fork context otherwise"""
|
||||
|
||||
def create_venv(
|
||||
self,
|
||||
@ -83,10 +86,16 @@ class VectorEnvType(Enum):
|
||||
return DummyVectorEnv(factories)
|
||||
case VectorEnvType.SUBPROC:
|
||||
return SubprocVectorEnv(factories)
|
||||
case VectorEnvType.SUBPROC_SHARED_MEM:
|
||||
case VectorEnvType.SUBPROC_SHARED_MEM_DEFAULT_CONTEXT:
|
||||
return SubprocVectorEnv(factories, share_memory=True)
|
||||
case VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT:
|
||||
return SubprocVectorEnv(factories, share_memory=True, context="fork")
|
||||
case VectorEnvType.SUBPROC_SHARED_MEM_AUTO:
|
||||
if platform.system().lower() == "windows":
|
||||
selected_venv_type = VectorEnvType.SUBPROC_SHARED_MEM_DEFAULT_CONTEXT
|
||||
else:
|
||||
selected_venv_type = VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT
|
||||
return selected_venv_type.create_venv(factories)
|
||||
case VectorEnvType.RAY:
|
||||
return RayVectorEnv(factories)
|
||||
case _:
|
||||
|
Loading…
x
Reference in New Issue
Block a user