from abc import ABC, abstractmethod from collections.abc import Callable, Sequence from enum import Enum from typing import Any, TypeAlias, cast import gymnasium as gym from tianshou.env import ( BaseVectorEnv, DummyVectorEnv, RayVectorEnv, ShmemVectorEnv, SubprocVectorEnv, ) from tianshou.highlevel.persistence import Persistence from tianshou.utils.net.common import TActionShape from tianshou.utils.string import ToStringMixin TObservationShape: TypeAlias = int | Sequence[int] class EnvType(Enum): """Enumeration of environment types.""" CONTINUOUS = "continuous" DISCRETE = "discrete" def is_discrete(self) -> bool: return self == EnvType.DISCRETE def is_continuous(self) -> bool: return self == EnvType.CONTINUOUS def assert_continuous(self, requiring_entity: Any) -> None: if not self.is_continuous(): raise AssertionError(f"{requiring_entity} requires continuous environments") def assert_discrete(self, requiring_entity: Any) -> None: if not self.is_discrete(): raise AssertionError(f"{requiring_entity} requires discrete environments") class VectorEnvType(Enum): DUMMY = "dummy" """Vectorized environment without parallelization; environments are processed sequentially""" SUBPROC = "subproc" """Parallelization based on `subprocess`""" SUBPROC_SHARED_MEM = "shmem" """Parallelization based on `subprocess` with shared memory""" RAY = "ray" """Parallelization based on the `ray` library""" def create_venv(self, factories: list[Callable[[], gym.Env]]) -> BaseVectorEnv: match self: case VectorEnvType.DUMMY: return DummyVectorEnv(factories) case VectorEnvType.SUBPROC: return SubprocVectorEnv(factories) case VectorEnvType.SUBPROC_SHARED_MEM: return ShmemVectorEnv(factories) case VectorEnvType.RAY: return RayVectorEnv(factories) case _: raise NotImplementedError(self) class Environments(ToStringMixin, ABC): """Represents (vectorized) environments.""" def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv): self.env = env self.train_envs = train_envs self.test_envs = test_envs self.persistence: Sequence[Persistence] = [] @staticmethod def from_factory_and_type( factory_fn: Callable[[], gym.Env], env_type: EnvType, venv_type: VectorEnvType, num_training_envs: int, num_test_envs: int, ) -> "Environments": """Creates a suitable subtype instance from a factory function that creates a single instance and the type of environment (continuous/discrete). :param factory_fn: the factory for a single environment instance :param env_type: the type of environments created by `factory_fn` :param venv_type: the vector environment type to use for parallelization :param num_training_envs: the number of training environments to create :param num_test_envs: the number of test environments to create :return: the instance """ train_envs = venv_type.create_venv([factory_fn] * num_training_envs) test_envs = venv_type.create_venv([factory_fn] * num_test_envs) env = factory_fn() match env_type: case EnvType.CONTINUOUS: return ContinuousEnvironments(env, train_envs, test_envs) case EnvType.DISCRETE: return DiscreteEnvironments(env, train_envs, test_envs) case _: raise ValueError(f"Environment type {env_type} not handled") def _tostring_includes(self) -> list[str]: return [] def _tostring_additional_entries(self) -> dict[str, Any]: return self.info() def info(self) -> dict[str, Any]: return { "action_shape": self.get_action_shape(), "state_shape": self.get_observation_shape(), } def set_persistence(self, *p: Persistence) -> None: """Associates the given persistence handlers which may persist and restore environment-specific information. :param p: persistence handlers """ self.persistence = p @abstractmethod def get_action_shape(self) -> TActionShape: pass @abstractmethod def get_observation_shape(self) -> TObservationShape: pass def get_action_space(self) -> gym.Space: return self.env.action_space def get_observation_space(self) -> gym.Space: return self.env.observation_space @abstractmethod def get_type(self) -> EnvType: pass class ContinuousEnvironments(Environments): """Represents (vectorized) continuous environments.""" def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv): super().__init__(env, train_envs, test_envs) self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env) @staticmethod def from_factory( factory_fn: Callable[[], gym.Env], venv_type: VectorEnvType, num_training_envs: int, num_test_envs: int, ) -> "ContinuousEnvironments": """Creates an instance from a factory function that creates a single instance. :param factory_fn: the factory for a single environment instance :param venv_type: the vector environment type to use for parallelization :param num_training_envs: the number of training environments to create :param num_test_envs: the number of test environments to create :return: the instance """ return cast( ContinuousEnvironments, Environments.from_factory_and_type( factory_fn, EnvType.CONTINUOUS, venv_type, num_training_envs, num_test_envs, ), ) def info(self) -> dict[str, Any]: d = super().info() d["max_action"] = self.max_action return d @staticmethod def _get_continuous_env_info( env: gym.Env, ) -> tuple[tuple[int, ...], tuple[int, ...], float]: if not isinstance(env.action_space, gym.spaces.Box): raise ValueError( "Only environments with continuous action space are supported here. " f"But got env with action space: {env.action_space.__class__}.", ) state_shape = env.observation_space.shape or env.observation_space.n # type: ignore if not state_shape: raise ValueError("Observation space shape is not defined") action_shape = env.action_space.shape max_action = env.action_space.high[0] return state_shape, action_shape, max_action def get_action_shape(self) -> TActionShape: return self.action_shape def get_observation_shape(self) -> TObservationShape: return self.state_shape def get_type(self) -> EnvType: return EnvType.CONTINUOUS class DiscreteEnvironments(Environments): """Represents (vectorized) discrete environments.""" def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv): super().__init__(env, train_envs, test_envs) self.observation_shape = env.observation_space.shape or env.observation_space.n # type: ignore self.action_shape = env.action_space.shape or env.action_space.n # type: ignore @staticmethod def from_factory( factory_fn: Callable[[], gym.Env], venv_type: VectorEnvType, num_training_envs: int, num_test_envs: int, ) -> "DiscreteEnvironments": """Creates an instance from a factory function that creates a single instance. :param factory_fn: the factory for a single environment instance :param venv_type: the vector environment type to use for parallelization :param num_training_envs: the number of training environments to create :param num_test_envs: the number of test environments to create :return: the instance """ return cast( DiscreteEnvironments, Environments.from_factory_and_type( factory_fn, EnvType.CONTINUOUS, venv_type, num_training_envs, num_test_envs, ), ) def get_action_shape(self) -> TActionShape: return self.action_shape def get_observation_shape(self) -> TObservationShape: return self.observation_shape def get_type(self) -> EnvType: return EnvType.DISCRETE class EnvFactory(ToStringMixin, ABC): @abstractmethod def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments: pass