Add convenient construction mechanisms for Environments
(based on a factory function that creates a single environment)
This commit is contained in:
parent
dd4a0eb430
commit
96298eafd8
@ -1,11 +1,17 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
from enum import Enum
|
||||
from typing import Any, TypeAlias
|
||||
from typing import Any, TypeAlias, cast
|
||||
|
||||
import gymnasium as gym
|
||||
|
||||
from tianshou.env import BaseVectorEnv
|
||||
from tianshou.env import (
|
||||
BaseVectorEnv,
|
||||
DummyVectorEnv,
|
||||
RayVectorEnv,
|
||||
ShmemVectorEnv,
|
||||
SubprocVectorEnv,
|
||||
)
|
||||
from tianshou.highlevel.persistence import Persistence
|
||||
from tianshou.utils.net.common import TActionShape
|
||||
from tianshou.utils.string import ToStringMixin
|
||||
@ -34,6 +40,30 @@ class EnvType(Enum):
|
||||
raise AssertionError(f"{requiring_entity} requires discrete environments")
|
||||
|
||||
|
||||
class VectorEnvType(Enum):
    """Enumerates the vectorized environment implementations that can be instantiated
    from a list of single-environment factories (see :meth:`create_venv`).
    """

    DUMMY = "dummy"
    """Vectorized environment without parallelization; environments are processed sequentially"""
    SUBPROC = "subproc"
    """Parallelization based on `subprocess`"""
    SUBPROC_SHARED_MEM = "shmem"
    """Parallelization based on `subprocess` with shared memory"""
    RAY = "ray"
    """Parallelization based on the `ray` library"""

    def create_venv(self, factories: list[Callable[[], gym.Env]]) -> BaseVectorEnv:
        """Instantiates the vectorized environment class that corresponds to this member.

        :param factories: the factory functions, each of which creates a single environment;
            the list is passed unchanged to the vectorized environment's constructor
        :return: the vectorized environment
        :raises NotImplementedError: if no vectorized environment class is associated
            with this member
        """
        # lookup table from enum member to the class implementing it
        venv_classes = {
            VectorEnvType.DUMMY: DummyVectorEnv,
            VectorEnvType.SUBPROC: SubprocVectorEnv,
            VectorEnvType.SUBPROC_SHARED_MEM: ShmemVectorEnv,
            VectorEnvType.RAY: RayVectorEnv,
        }
        venv_class = venv_classes.get(self)
        if venv_class is None:
            raise NotImplementedError(self)
        return venv_class(factories)
|
||||
|
||||
|
||||
class Environments(ToStringMixin, ABC):
|
||||
"""Represents (vectorized) environments."""
|
||||
|
||||
@ -43,6 +73,35 @@ class Environments(ToStringMixin, ABC):
|
||||
self.test_envs = test_envs
|
||||
self.persistence: Sequence[Persistence] = []
|
||||
|
||||
@staticmethod
def from_factory_and_type(
    factory_fn: Callable[[], gym.Env],
    env_type: EnvType,
    venv_type: VectorEnvType,
    num_training_envs: int,
    num_test_envs: int,
) -> "Environments":
    """Builds the `Environments` subtype matching `env_type`, using `factory_fn` to create
    every individual environment.

    The factory is used for the training and test vector environments and is additionally
    called once directly; the resulting single environment is passed as the first
    constructor argument of the returned instance.

    :param factory_fn: the factory for a single environment instance
    :param env_type: the type of environments created by `factory_fn` (continuous/discrete)
    :param venv_type: the vector environment type to use for parallelization
    :param num_training_envs: the number of training environments to create
    :param num_test_envs: the number of test environments to create
    :return: the instance
    :raises ValueError: if `env_type` is neither continuous nor discrete
    """
    training_factories = [factory_fn] * num_training_envs
    train_envs = venv_type.create_venv(training_factories)
    test_factories = [factory_fn] * num_test_envs
    test_envs = venv_type.create_venv(test_factories)
    # one additional, non-vectorized instance for the subtype's constructor
    env = factory_fn()

    if env_type == EnvType.CONTINUOUS:
        return ContinuousEnvironments(env, train_envs, test_envs)
    if env_type == EnvType.DISCRETE:
        return DiscreteEnvironments(env, train_envs, test_envs)
    raise ValueError(f"Environment type {env_type} not handled")
|
||||
|
||||
def _tostring_includes(self) -> list[str]:
    """Returns an empty list.

    NOTE(review): presumably consumed by `ToStringMixin` to decide which attributes appear
    in the string representation; the exact interpretation of an empty list is defined
    there and is not visible here -- confirm.
    """
    included: list[str] = []
    return included
|
||||
|
||||
@ -89,6 +148,32 @@ class ContinuousEnvironments(Environments):
|
||||
super().__init__(env, train_envs, test_envs)
|
||||
self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env)
|
||||
|
||||
@staticmethod
def from_factory(
    factory_fn: Callable[[], gym.Env],
    venv_type: VectorEnvType,
    num_training_envs: int,
    num_test_envs: int,
) -> "ContinuousEnvironments":
    """Builds continuous environments from a factory function for a single environment.

    Delegates to `Environments.from_factory_and_type` with `EnvType.CONTINUOUS`.

    :param factory_fn: the factory for a single environment instance
    :param venv_type: the vector environment type to use for parallelization
    :param num_training_envs: the number of training environments to create
    :param num_test_envs: the number of test environments to create
    :return: the instance
    """
    envs = Environments.from_factory_and_type(
        factory_fn=factory_fn,
        env_type=EnvType.CONTINUOUS,
        venv_type=venv_type,
        num_training_envs=num_training_envs,
        num_test_envs=num_test_envs,
    )
    # the generic factory is annotated to return the base type; narrow it for type checkers
    return cast(ContinuousEnvironments, envs)
|
||||
|
||||
def info(self) -> dict[str, Any]:
|
||||
d = super().info()
|
||||
d["max_action"] = self.max_action
|
||||
@ -128,6 +213,32 @@ class DiscreteEnvironments(Environments):
|
||||
self.observation_shape = env.observation_space.shape or env.observation_space.n # type: ignore
|
||||
self.action_shape = env.action_space.shape or env.action_space.n # type: ignore
|
||||
|
||||
@staticmethod
def from_factory(
    factory_fn: Callable[[], gym.Env],
    venv_type: VectorEnvType,
    num_training_envs: int,
    num_test_envs: int,
) -> "DiscreteEnvironments":
    """Creates an instance from a factory function that creates a single instance.

    Delegates to `Environments.from_factory_and_type` with `EnvType.DISCRETE`.

    :param factory_fn: the factory for a single environment instance
    :param venv_type: the vector environment type to use for parallelization
    :param num_training_envs: the number of training environments to create
    :param num_test_envs: the number of test environments to create
    :return: the instance
    """
    # `cast` performs no runtime check, so the env type passed here is what actually
    # determines the class of the returned object; it must be DISCRETE (not CONTINUOUS)
    # for the result to really be a DiscreteEnvironments instance.
    return cast(
        DiscreteEnvironments,
        Environments.from_factory_and_type(
            factory_fn,
            EnvType.DISCRETE,
            venv_type,
            num_training_envs,
            num_test_envs,
        ),
    )
|
||||
|
||||
def get_action_shape(self) -> TActionShape:
    """Returns the value of this instance's `action_shape` attribute."""
    action_shape = self.action_shape
    return action_shape
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user