Add convenient construction mechanisms for Environments

(based on factory function for a single environment)
This commit is contained in:
Dominik Jain 2023-10-24 15:38:16 +02:00 committed by Dominik Jain
parent dd4a0eb430
commit 96298eafd8

View File

@ -1,11 +1,17 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from enum import Enum
from typing import Any, TypeAlias
from typing import Any, TypeAlias, cast
import gymnasium as gym
from tianshou.env import BaseVectorEnv
from tianshou.env import (
BaseVectorEnv,
DummyVectorEnv,
RayVectorEnv,
ShmemVectorEnv,
SubprocVectorEnv,
)
from tianshou.highlevel.persistence import Persistence
from tianshou.utils.net.common import TActionShape
from tianshou.utils.string import ToStringMixin
@ -34,6 +40,30 @@ class EnvType(Enum):
raise AssertionError(f"{requiring_entity} requires discrete environments")
class VectorEnvType(Enum):
DUMMY = "dummy"
"""Vectorized environment without parallelization; environments are processed sequentially"""
SUBPROC = "subproc"
"""Parallelization based on `subprocess`"""
SUBPROC_SHARED_MEM = "shmem"
"""Parallelization based on `subprocess` with shared memory"""
RAY = "ray"
"""Parallelization based on the `ray` library"""
def create_venv(self, factories: list[Callable[[], gym.Env]]) -> BaseVectorEnv:
match self:
case VectorEnvType.DUMMY:
return DummyVectorEnv(factories)
case VectorEnvType.SUBPROC:
return SubprocVectorEnv(factories)
case VectorEnvType.SUBPROC_SHARED_MEM:
return ShmemVectorEnv(factories)
case VectorEnvType.RAY:
return RayVectorEnv(factories)
case _:
raise NotImplementedError(self)
class Environments(ToStringMixin, ABC):
"""Represents (vectorized) environments."""
@ -43,6 +73,35 @@ class Environments(ToStringMixin, ABC):
self.test_envs = test_envs
self.persistence: Sequence[Persistence] = []
@staticmethod
def from_factory_and_type(
factory_fn: Callable[[], gym.Env],
env_type: EnvType,
venv_type: VectorEnvType,
num_training_envs: int,
num_test_envs: int,
) -> "Environments":
"""Creates a suitable subtype instance from a factory function that creates a single instance and
the type of environment (continuous/discrete).
:param factory_fn: the factory for a single environment instance
:param env_type: the type of environments created by `factory_fn`
:param venv_type: the vector environment type to use for parallelization
:param num_training_envs: the number of training environments to create
:param num_test_envs: the number of test environments to create
:return: the instance
"""
train_envs = venv_type.create_venv([factory_fn] * num_training_envs)
test_envs = venv_type.create_venv([factory_fn] * num_test_envs)
env = factory_fn()
match env_type:
case EnvType.CONTINUOUS:
return ContinuousEnvironments(env, train_envs, test_envs)
case EnvType.DISCRETE:
return DiscreteEnvironments(env, train_envs, test_envs)
case _:
raise ValueError(f"Environment type {env_type} not handled")
def _tostring_includes(self) -> list[str]:
return []
@ -89,6 +148,32 @@ class ContinuousEnvironments(Environments):
super().__init__(env, train_envs, test_envs)
self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env)
@staticmethod
def from_factory(
factory_fn: Callable[[], gym.Env],
venv_type: VectorEnvType,
num_training_envs: int,
num_test_envs: int,
) -> "ContinuousEnvironments":
"""Creates an instance from a factory function that creates a single instance.
:param factory_fn: the factory for a single environment instance
:param venv_type: the vector environment type to use for parallelization
:param num_training_envs: the number of training environments to create
:param num_test_envs: the number of test environments to create
:return: the instance
"""
return cast(
ContinuousEnvironments,
Environments.from_factory_and_type(
factory_fn,
EnvType.CONTINUOUS,
venv_type,
num_training_envs,
num_test_envs,
),
)
def info(self) -> dict[str, Any]:
d = super().info()
d["max_action"] = self.max_action
@ -128,6 +213,32 @@ class DiscreteEnvironments(Environments):
self.observation_shape = env.observation_space.shape or env.observation_space.n # type: ignore
self.action_shape = env.action_space.shape or env.action_space.n # type: ignore
@staticmethod
def from_factory(
factory_fn: Callable[[], gym.Env],
venv_type: VectorEnvType,
num_training_envs: int,
num_test_envs: int,
) -> "DiscreteEnvironments":
"""Creates an instance from a factory function that creates a single instance.
:param factory_fn: the factory for a single environment instance
:param venv_type: the vector environment type to use for parallelization
:param num_training_envs: the number of training environments to create
:param num_test_envs: the number of test environments to create
:return: the instance
"""
return cast(
DiscreteEnvironments,
Environments.from_factory_and_type(
factory_fn,
EnvType.CONTINUOUS,
venv_type,
num_training_envs,
num_test_envs,
),
)
def get_action_shape(self) -> TActionShape:
return self.action_shape