Added and used new VenvType: SUBPROC_SHARED_MEM_AUTO
This commit is contained in:
parent
d58ae163f2
commit
1cd22f1d32
@ -397,6 +397,7 @@ class AtariEnvFactory(EnvFactoryRegistered):
|
|||||||
frame_stack: int,
|
frame_stack: int,
|
||||||
scale: bool = False,
|
scale: bool = False,
|
||||||
use_envpool_if_available: bool = True,
|
use_envpool_if_available: bool = True,
|
||||||
|
venv_type: VectorEnvType = VectorEnvType.SUBPROC_SHARED_MEM_AUTO,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert "NoFrameskip" in task
|
assert "NoFrameskip" in task
|
||||||
self.frame_stack = frame_stack
|
self.frame_stack = frame_stack
|
||||||
@ -412,7 +413,7 @@ class AtariEnvFactory(EnvFactoryRegistered):
|
|||||||
task=task,
|
task=task,
|
||||||
train_seed=train_seed,
|
train_seed=train_seed,
|
||||||
test_seed=test_seed,
|
test_seed=test_seed,
|
||||||
venv_type=VectorEnvType.SUBPROC_SHARED_MEM,
|
venv_type=venv_type,
|
||||||
envpool_factory=envpool_factory,
|
envpool_factory=envpool_factory,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -76,7 +76,7 @@ class MujocoEnvFactory(EnvFactoryRegistered):
|
|||||||
train_seed: int,
|
train_seed: int,
|
||||||
test_seed: int,
|
test_seed: int,
|
||||||
obs_norm: bool = True,
|
obs_norm: bool = True,
|
||||||
venv_type: VectorEnvType = VectorEnvType.SUBPROC_SHARED_MEM,
|
venv_type: VectorEnvType = VectorEnvType.SUBPROC_SHARED_MEM_AUTO,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
task=task,
|
task=task,
|
||||||
|
@ -14,6 +14,8 @@ These plots are saved in the log directory and displayed in the console.
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -21,7 +23,6 @@ from examples.mujoco.mujoco_env import MujocoEnvFactory
|
|||||||
from tianshou.evaluation.launcher import RegisteredExpLauncher
|
from tianshou.evaluation.launcher import RegisteredExpLauncher
|
||||||
from tianshou.evaluation.rliable_evaluation_hl import RLiableExperimentResult
|
from tianshou.evaluation.rliable_evaluation_hl import RLiableExperimentResult
|
||||||
from tianshou.highlevel.config import SamplingConfig
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
from tianshou.highlevel.env import VectorEnvType
|
|
||||||
from tianshou.highlevel.experiment import (
|
from tianshou.highlevel.experiment import (
|
||||||
ExperimentConfig,
|
ExperimentConfig,
|
||||||
PPOExperimentBuilder,
|
PPOExperimentBuilder,
|
||||||
@ -70,9 +71,6 @@ def main(
|
|||||||
train_seed=sampling_config.train_seed,
|
train_seed=sampling_config.train_seed,
|
||||||
test_seed=sampling_config.test_seed,
|
test_seed=sampling_config.test_seed,
|
||||||
obs_norm=True,
|
obs_norm=True,
|
||||||
venv_type=VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT
|
|
||||||
if sys.platform == "darwin"
|
|
||||||
else VectorEnvType.SUBPROC_SHARED_MEM,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_sizes = (64, 64)
|
hidden_sizes = (64, 64)
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import platform
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@ -66,13 +67,15 @@ class VectorEnvType(Enum):
|
|||||||
"""Vectorized environment without parallelization; environments are processed sequentially"""
|
"""Vectorized environment without parallelization; environments are processed sequentially"""
|
||||||
SUBPROC = "subproc"
|
SUBPROC = "subproc"
|
||||||
"""Parallelization based on `subprocess`"""
|
"""Parallelization based on `subprocess`"""
|
||||||
SUBPROC_SHARED_MEM = "shmem"
|
SUBPROC_SHARED_MEM_DEFAULT_CONTEXT = "shmem"
|
||||||
"""Parallelization based on `subprocess` with shared memory"""
|
"""Parallelization based on `subprocess` with shared memory"""
|
||||||
SUBPROC_SHARED_MEM_FORK_CONTEXT = "shmem_fork"
|
SUBPROC_SHARED_MEM_FORK_CONTEXT = "shmem_fork"
|
||||||
"""Parallelization based on `subprocess` with shared memory and fork context (relevant for macOS, which uses `spawn`
|
"""Parallelization based on `subprocess` with shared memory and fork context (relevant for macOS, which uses `spawn`
|
||||||
by default https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods)"""
|
by default https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods)"""
|
||||||
RAY = "ray"
|
RAY = "ray"
|
||||||
"""Parallelization based on the `ray` library"""
|
"""Parallelization based on the `ray` library"""
|
||||||
|
SUBPROC_SHARED_MEM_AUTO = "subproc_shared_mem_auto"
|
||||||
|
"""Parallelization based on `subprocess` with shared memory, using default context on windows and fork context otherwise"""
|
||||||
|
|
||||||
def create_venv(
|
def create_venv(
|
||||||
self,
|
self,
|
||||||
@ -83,10 +86,16 @@ class VectorEnvType(Enum):
|
|||||||
return DummyVectorEnv(factories)
|
return DummyVectorEnv(factories)
|
||||||
case VectorEnvType.SUBPROC:
|
case VectorEnvType.SUBPROC:
|
||||||
return SubprocVectorEnv(factories)
|
return SubprocVectorEnv(factories)
|
||||||
case VectorEnvType.SUBPROC_SHARED_MEM:
|
case VectorEnvType.SUBPROC_SHARED_MEM_DEFAULT_CONTEXT:
|
||||||
return SubprocVectorEnv(factories, share_memory=True)
|
return SubprocVectorEnv(factories, share_memory=True)
|
||||||
case VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT:
|
case VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT:
|
||||||
return SubprocVectorEnv(factories, share_memory=True, context="fork")
|
return SubprocVectorEnv(factories, share_memory=True, context="fork")
|
||||||
|
case VectorEnvType.SUBPROC_SHARED_MEM_AUTO:
|
||||||
|
if platform.system().lower() == "windows":
|
||||||
|
selected_venv_type = VectorEnvType.SUBPROC_SHARED_MEM_DEFAULT_CONTEXT
|
||||||
|
else:
|
||||||
|
selected_venv_type = VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT
|
||||||
|
return selected_venv_type.create_venv(factories)
|
||||||
case VectorEnvType.RAY:
|
case VectorEnvType.RAY:
|
||||||
return RayVectorEnv(factories)
|
return RayVectorEnv(factories)
|
||||||
case _:
|
case _:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user