of number of environments in SamplingConfig is used (values are now passed to factory method). This is clearer and removes the need to pass otherwise unnecessary configuration to environment factories at construction.
158 lines
5.0 KiB
Python
158 lines
5.0 KiB
Python
from abc import ABC, abstractmethod
|
|
from collections.abc import Sequence
|
|
from enum import Enum
|
|
from typing import Any, TypeAlias
|
|
|
|
import gymnasium as gym
|
|
|
|
from tianshou.env import BaseVectorEnv
|
|
from tianshou.highlevel.persistence import PersistableConfigProtocol, Persistence
|
|
from tianshou.utils.net.common import TActionShape
|
|
from tianshou.utils.string import ToStringMixin
|
|
|
|
# Type of an observation shape: either a single int or a sequence of ints.
# Used below as the return type of `Environments.get_observation_shape`.
TObservationShape: TypeAlias = int | Sequence[int]
|
|
|
|
|
|
class EnvType(Enum):
    """Enumerates the kinds of environment: continuous or discrete."""

    CONTINUOUS = "continuous"
    DISCRETE = "discrete"

    def is_discrete(self) -> bool:
        """Return True if this member is ``DISCRETE``."""
        return self is EnvType.DISCRETE

    def is_continuous(self) -> bool:
        """Return True if this member is ``CONTINUOUS``."""
        return self is EnvType.CONTINUOUS

    def assert_continuous(self, requiring_entity: Any) -> None:
        """Raise an error unless this member is ``CONTINUOUS``.

        :param requiring_entity: the entity imposing the requirement; it is
            interpolated into the error message
        :raises AssertionError: if this member is not ``CONTINUOUS``
        """
        if self.is_continuous():
            return
        raise AssertionError(f"{requiring_entity} requires continuous environments")

    def assert_discrete(self, requiring_entity: Any) -> None:
        """Raise an error unless this member is ``DISCRETE``.

        :param requiring_entity: the entity imposing the requirement; it is
            interpolated into the error message
        :raises AssertionError: if this member is not ``DISCRETE``
        """
        if self.is_discrete():
            return
        raise AssertionError(f"{requiring_entity} requires discrete environments")
|
|
|
|
|
|
class Environments(ToStringMixin, ABC):
    """Bundles a single environment with vectorized training and test environments."""

    def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
        """Store the given environments and start with no persistence handlers.

        :param env: the single environment; its spaces are exposed via
            :meth:`get_action_space` and :meth:`get_observation_space`
        :param train_envs: the vectorized training environments
        :param test_envs: the vectorized test environments
        """
        self.env = env
        self.train_envs = train_envs
        self.test_envs = test_envs
        # Replaced wholesale by set_persistence().
        self.persistence: Sequence[Persistence] = []

    def _tostring_includes(self) -> list[str]:
        return []

    def _tostring_additional_entries(self) -> dict[str, Any]:
        # The string representation reports the same entries as info().
        return self.info()

    def info(self) -> dict[str, Any]:
        """Return a summary of the environments.

        :return: a dict with the action shape under ``action_shape`` and the
            observation shape under ``state_shape``
        """
        summary: dict[str, Any] = {}
        summary["action_shape"] = self.get_action_shape()
        summary["state_shape"] = self.get_observation_shape()
        return summary

    def set_persistence(self, *p: Persistence) -> None:
        """Set the persistence handlers, which may persist and restore
        environment-specific information.

        Any previously set handlers are replaced.

        :param p: persistence handlers
        """
        self.persistence = p

    @abstractmethod
    def get_action_shape(self) -> TActionShape:
        """Return the shape of the action space."""

    @abstractmethod
    def get_observation_shape(self) -> TObservationShape:
        """Return the shape of the observation space."""

    def get_action_space(self) -> gym.Space:
        """Return the action space of the single environment ``env``."""
        action_space = self.env.action_space
        return action_space

    def get_observation_space(self) -> gym.Space:
        """Return the observation space of the single environment ``env``."""
        observation_space = self.env.observation_space
        return observation_space

    @abstractmethod
    def get_type(self) -> EnvType:
        """Return the type of these environments."""
|
|
|
|
|
|
class ContinuousEnvironments(Environments):
    """Represents (vectorized) continuous environments."""

    def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
        """Store the environments and derive shape information from ``env``.

        :param env: the single environment from which the observation shape,
            action shape and maximum action are read
        :param train_envs: the vectorized training environments
        :param test_envs: the vectorized test environments
        :raises ValueError: if the action space of ``env`` is not a ``gym.spaces.Box``
            or if no observation shape can be determined
        """
        super().__init__(env, train_envs, test_envs)
        self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env)

    def info(self) -> dict[str, Any]:
        """Return the base summary extended by the entry ``max_action``."""
        d = super().info()
        d["max_action"] = self.max_action
        return d

    @staticmethod
    def _get_continuous_env_info(
        env: gym.Env,
    ) -> tuple[tuple[int, ...], tuple[int, ...], float]:
        """Extract observation shape, action shape and maximum action from ``env``.

        :param env: the environment to inspect
        :return: a tuple ``(state_shape, action_shape, max_action)`` where
            ``max_action`` is the first entry of the action space's ``high`` bound
        :raises ValueError: if the action space is not a ``gym.spaces.Box`` or
            if the observation space provides neither a non-empty ``shape``
            nor a non-zero ``n``
        """
        if not isinstance(env.action_space, gym.spaces.Box):
            raise ValueError(
                "Only environments with continuous action space are supported here. "
                f"But got env with action space: {env.action_space.__class__}.",
            )
        observation_space = env.observation_space
        # Fall back to `n` only if the space has it: accessing `.n` directly raised
        # an AttributeError for spaces with an empty/None shape and no `n`,
        # which made the ValueError below unreachable for exactly those spaces.
        state_shape = observation_space.shape or getattr(observation_space, "n", None)
        if not state_shape:
            raise ValueError("Observation space shape is not defined")
        action_shape = env.action_space.shape
        # Only the first entry of the upper bound is used as the maximum action.
        max_action = env.action_space.high[0]
        return state_shape, action_shape, max_action

    def get_action_shape(self) -> TActionShape:
        """Return the action shape determined at construction."""
        return self.action_shape

    def get_observation_shape(self) -> TObservationShape:
        """Return the observation shape determined at construction."""
        return self.state_shape

    def get_type(self) -> EnvType:
        """Return :attr:`EnvType.CONTINUOUS`."""
        return EnvType.CONTINUOUS
|
|
|
|
|
|
class DiscreteEnvironments(Environments):
    """Represents (vectorized) discrete environments."""

    def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
        """Store the environments and derive the shapes from the spaces of ``env``.

        For each space, the ``shape`` is used if it is truthy; otherwise the
        space's ``n`` is used.

        :param env: the single environment whose spaces are inspected
        :param train_envs: the vectorized training environments
        :param test_envs: the vectorized test environments
        """
        super().__init__(env, train_envs, test_envs)
        obs_space = env.observation_space
        self.observation_shape = obs_space.shape or obs_space.n  # type: ignore
        act_space = env.action_space
        self.action_shape = act_space.shape or act_space.n  # type: ignore

    def get_action_shape(self) -> TActionShape:
        """Return the action shape determined at construction."""
        return self.action_shape

    def get_observation_shape(self) -> TObservationShape:
        """Return the observation shape determined at construction."""
        return self.observation_shape

    def get_type(self) -> EnvType:
        """Return :attr:`EnvType.DISCRETE`."""
        return EnvType.DISCRETE
|
|
|
|
|
|
class EnvFactory(ToStringMixin, ABC):
    """Abstract factory for :class:`Environments`; instances are also callable."""

    @abstractmethod
    def create_envs(
        self,
        num_training_envs: int,
        num_test_envs: int,
        config: PersistableConfigProtocol | None = None,
    ) -> Environments:
        """Create the environments.

        :param num_training_envs: the number of training environments
        :param num_test_envs: the number of test environments
        :param config: an optional persistable configuration; how it is used
            is up to the implementation
        :return: the environments
        """

    def __call__(
        self,
        num_training_envs: int,
        num_test_envs: int,
        config: PersistableConfigProtocol | None = None,
    ) -> Environments:
        """Delegate to :meth:`create_envs` with the same arguments.

        :param num_training_envs: the number of training environments
        :param num_test_envs: the number of test environments
        :param config: an optional persistable configuration, passed on as a
            keyword argument
        :return: the result of :meth:`create_envs`
        """
        envs = self.create_envs(num_training_envs, num_test_envs, config=config)
        return envs
|