Add convenient construction mechanisms for Environments
(based on factory function for a single environment)
This commit is contained in:
parent
dd4a0eb430
commit
96298eafd8
@ -1,11 +1,17 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, TypeAlias
|
from typing import Any, TypeAlias, cast
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
|
||||||
from tianshou.env import BaseVectorEnv
|
from tianshou.env import (
|
||||||
|
BaseVectorEnv,
|
||||||
|
DummyVectorEnv,
|
||||||
|
RayVectorEnv,
|
||||||
|
ShmemVectorEnv,
|
||||||
|
SubprocVectorEnv,
|
||||||
|
)
|
||||||
from tianshou.highlevel.persistence import Persistence
|
from tianshou.highlevel.persistence import Persistence
|
||||||
from tianshou.utils.net.common import TActionShape
|
from tianshou.utils.net.common import TActionShape
|
||||||
from tianshou.utils.string import ToStringMixin
|
from tianshou.utils.string import ToStringMixin
|
||||||
@ -34,6 +40,30 @@ class EnvType(Enum):
|
|||||||
raise AssertionError(f"{requiring_entity} requires discrete environments")
|
raise AssertionError(f"{requiring_entity} requires discrete environments")
|
||||||
|
|
||||||
|
|
||||||
|
class VectorEnvType(Enum):
|
||||||
|
DUMMY = "dummy"
|
||||||
|
"""Vectorized environment without parallelization; environments are processed sequentially"""
|
||||||
|
SUBPROC = "subproc"
|
||||||
|
"""Parallelization based on `subprocess`"""
|
||||||
|
SUBPROC_SHARED_MEM = "shmem"
|
||||||
|
"""Parallelization based on `subprocess` with shared memory"""
|
||||||
|
RAY = "ray"
|
||||||
|
"""Parallelization based on the `ray` library"""
|
||||||
|
|
||||||
|
def create_venv(self, factories: list[Callable[[], gym.Env]]) -> BaseVectorEnv:
|
||||||
|
match self:
|
||||||
|
case VectorEnvType.DUMMY:
|
||||||
|
return DummyVectorEnv(factories)
|
||||||
|
case VectorEnvType.SUBPROC:
|
||||||
|
return SubprocVectorEnv(factories)
|
||||||
|
case VectorEnvType.SUBPROC_SHARED_MEM:
|
||||||
|
return ShmemVectorEnv(factories)
|
||||||
|
case VectorEnvType.RAY:
|
||||||
|
return RayVectorEnv(factories)
|
||||||
|
case _:
|
||||||
|
raise NotImplementedError(self)
|
||||||
|
|
||||||
|
|
||||||
class Environments(ToStringMixin, ABC):
|
class Environments(ToStringMixin, ABC):
|
||||||
"""Represents (vectorized) environments."""
|
"""Represents (vectorized) environments."""
|
||||||
|
|
||||||
@ -43,6 +73,35 @@ class Environments(ToStringMixin, ABC):
|
|||||||
self.test_envs = test_envs
|
self.test_envs = test_envs
|
||||||
self.persistence: Sequence[Persistence] = []
|
self.persistence: Sequence[Persistence] = []
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_factory_and_type(
|
||||||
|
factory_fn: Callable[[], gym.Env],
|
||||||
|
env_type: EnvType,
|
||||||
|
venv_type: VectorEnvType,
|
||||||
|
num_training_envs: int,
|
||||||
|
num_test_envs: int,
|
||||||
|
) -> "Environments":
|
||||||
|
"""Creates a suitable subtype instance from a factory function that creates a single instance and
|
||||||
|
the type of environment (continuous/discrete).
|
||||||
|
|
||||||
|
:param factory_fn: the factory for a single environment instance
|
||||||
|
:param env_type: the type of environments created by `factory_fn`
|
||||||
|
:param venv_type: the vector environment type to use for parallelization
|
||||||
|
:param num_training_envs: the number of training environments to create
|
||||||
|
:param num_test_envs: the number of test environments to create
|
||||||
|
:return: the instance
|
||||||
|
"""
|
||||||
|
train_envs = venv_type.create_venv([factory_fn] * num_training_envs)
|
||||||
|
test_envs = venv_type.create_venv([factory_fn] * num_test_envs)
|
||||||
|
env = factory_fn()
|
||||||
|
match env_type:
|
||||||
|
case EnvType.CONTINUOUS:
|
||||||
|
return ContinuousEnvironments(env, train_envs, test_envs)
|
||||||
|
case EnvType.DISCRETE:
|
||||||
|
return DiscreteEnvironments(env, train_envs, test_envs)
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Environment type {env_type} not handled")
|
||||||
|
|
||||||
def _tostring_includes(self) -> list[str]:
|
def _tostring_includes(self) -> list[str]:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@ -89,6 +148,32 @@ class ContinuousEnvironments(Environments):
|
|||||||
super().__init__(env, train_envs, test_envs)
|
super().__init__(env, train_envs, test_envs)
|
||||||
self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env)
|
self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_factory(
|
||||||
|
factory_fn: Callable[[], gym.Env],
|
||||||
|
venv_type: VectorEnvType,
|
||||||
|
num_training_envs: int,
|
||||||
|
num_test_envs: int,
|
||||||
|
) -> "ContinuousEnvironments":
|
||||||
|
"""Creates an instance from a factory function that creates a single instance.
|
||||||
|
|
||||||
|
:param factory_fn: the factory for a single environment instance
|
||||||
|
:param venv_type: the vector environment type to use for parallelization
|
||||||
|
:param num_training_envs: the number of training environments to create
|
||||||
|
:param num_test_envs: the number of test environments to create
|
||||||
|
:return: the instance
|
||||||
|
"""
|
||||||
|
return cast(
|
||||||
|
ContinuousEnvironments,
|
||||||
|
Environments.from_factory_and_type(
|
||||||
|
factory_fn,
|
||||||
|
EnvType.CONTINUOUS,
|
||||||
|
venv_type,
|
||||||
|
num_training_envs,
|
||||||
|
num_test_envs,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
def info(self) -> dict[str, Any]:
|
def info(self) -> dict[str, Any]:
|
||||||
d = super().info()
|
d = super().info()
|
||||||
d["max_action"] = self.max_action
|
d["max_action"] = self.max_action
|
||||||
@ -128,6 +213,32 @@ class DiscreteEnvironments(Environments):
|
|||||||
self.observation_shape = env.observation_space.shape or env.observation_space.n # type: ignore
|
self.observation_shape = env.observation_space.shape or env.observation_space.n # type: ignore
|
||||||
self.action_shape = env.action_space.shape or env.action_space.n # type: ignore
|
self.action_shape = env.action_space.shape or env.action_space.n # type: ignore
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_factory(
|
||||||
|
factory_fn: Callable[[], gym.Env],
|
||||||
|
venv_type: VectorEnvType,
|
||||||
|
num_training_envs: int,
|
||||||
|
num_test_envs: int,
|
||||||
|
) -> "DiscreteEnvironments":
|
||||||
|
"""Creates an instance from a factory function that creates a single instance.
|
||||||
|
|
||||||
|
:param factory_fn: the factory for a single environment instance
|
||||||
|
:param venv_type: the vector environment type to use for parallelization
|
||||||
|
:param num_training_envs: the number of training environments to create
|
||||||
|
:param num_test_envs: the number of test environments to create
|
||||||
|
:return: the instance
|
||||||
|
"""
|
||||||
|
return cast(
|
||||||
|
DiscreteEnvironments,
|
||||||
|
Environments.from_factory_and_type(
|
||||||
|
factory_fn,
|
||||||
|
EnvType.CONTINUOUS,
|
||||||
|
venv_type,
|
||||||
|
num_training_envs,
|
||||||
|
num_test_envs,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
def get_action_shape(self) -> TActionShape:
|
def get_action_shape(self) -> TActionShape:
|
||||||
return self.action_shape
|
return self.action_shape
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user