2023-09-19 18:53:11 +02:00
|
|
|
from abc import ABC, abstractmethod
|
2023-09-20 09:29:34 +02:00
|
|
|
from collections.abc import Sequence
|
2023-09-21 12:36:27 +02:00
|
|
|
from enum import Enum
|
2023-09-20 09:29:34 +02:00
|
|
|
from typing import Any
|
2023-09-19 18:53:11 +02:00
|
|
|
|
|
|
|
import gymnasium as gym
|
|
|
|
|
|
|
|
from tianshou.env import BaseVectorEnv
|
2023-09-27 14:10:45 +02:00
|
|
|
from tianshou.highlevel.persistence import PersistableConfigProtocol
|
2023-09-19 18:53:11 +02:00
|
|
|
|
2023-09-20 09:29:34 +02:00
|
|
|
# Shape of an action or observation space as returned by the shape getters in
# this module: either a single int or a sequence of ints.
TShape = int | Sequence[int]
|
2023-09-19 18:53:11 +02:00
|
|
|
|
|
|
|
|
2023-09-21 12:36:27 +02:00
|
|
|
class EnvType(Enum):
    """Kind of environment, distinguishing continuous from discrete ones."""

    CONTINUOUS = "continuous"
    DISCRETE = "discrete"

    def is_discrete(self) -> bool:
        """Return True if this member is ``EnvType.DISCRETE``."""
        # Enum members are singletons, so identity is the idiomatic comparison.
        return self is EnvType.DISCRETE

    def is_continuous(self) -> bool:
        """Return True if this member is ``EnvType.CONTINUOUS``."""
        return self is EnvType.CONTINUOUS

    def assert_continuous(self, requiring_entity: Any) -> None:
        """Raise if this environment type is not continuous.

        :param requiring_entity: the object that needs a continuous environment;
            it is interpolated into the error message.
        :raises AssertionError: if this member is not ``EnvType.CONTINUOUS``.
        """
        if not self.is_continuous():
            raise AssertionError(f"{requiring_entity} requires continuous environments")

    def assert_discrete(self, requiring_entity: Any) -> None:
        """Raise if this environment type is not discrete.

        :param requiring_entity: the object that needs a discrete environment;
            it is interpolated into the error message.
        :raises AssertionError: if this member is not ``EnvType.DISCRETE``.
        """
        if not self.is_discrete():
            raise AssertionError(f"{requiring_entity} requires discrete environments")
|
|
|
|
|
2023-09-21 12:36:27 +02:00
|
|
|
|
2023-09-19 18:53:11 +02:00
|
|
|
class Environments(ABC):
    """Abstract container holding ``env`` together with ``train_envs`` and ``test_envs``.

    Subclasses supply the action/observation shapes and the environment type.
    """

    def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
        """Store the given environments on the instance without further processing."""
        self.env, self.train_envs, self.test_envs = env, train_envs, test_envs

    def info(self) -> dict[str, Any]:
        """Return a dict with the keys ``action_shape`` and ``state_shape``."""
        summary: dict[str, Any] = {}
        summary["action_shape"] = self.get_action_shape()
        summary["state_shape"] = self.get_observation_shape()
        return summary

    @abstractmethod
    def get_action_shape(self) -> TShape:
        """Return the shape of the action space."""

    @abstractmethod
    def get_observation_shape(self) -> TShape:
        """Return the shape of the observation space."""

    def get_action_space(self) -> gym.Space:
        """Return ``action_space`` of the stored ``env``.

        NOTE(review): ``env`` is typed as optional, but it is dereferenced here
        without a check, so a ``None`` env fails with an AttributeError.
        """
        reference_env = self.env
        return reference_env.action_space

    def get_observation_space(self) -> gym.Space:
        """Return ``observation_space`` of the stored ``env``.

        NOTE(review): as with the action space, a ``None`` env is not handled.
        """
        reference_env = self.env
        return reference_env.observation_space

    @abstractmethod
    def get_type(self) -> EnvType:
        """Return the :class:`EnvType` of these environments."""
|
|
|
|
|
2023-09-19 18:53:11 +02:00
|
|
|
|
|
|
|
class ContinuousEnvironments(Environments):
    """Environments whose action space is a ``gym.spaces.Box``.

    On construction, the observation shape, action shape and maximum action
    value are read from ``env`` and stored as ``state_shape``, ``action_shape``
    and ``max_action``.
    """

    def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
        """Store the environments and extract shape/bound information from ``env``.

        NOTE(review): the signature allows ``env=None``, but the extraction below
        dereferences ``env``, so ``None`` is not actually usable here.

        :raises ValueError: if the action space of ``env`` is not a ``gym.spaces.Box``
            or if no observation shape can be determined.
        """
        super().__init__(env, train_envs, test_envs)
        self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env)

    def info(self) -> dict[str, Any]:
        """Return the base info dict extended by the key ``max_action``."""
        d = super().info()
        d["max_action"] = self.max_action
        return d

    @staticmethod
    def _get_continuous_env_info(
        env: gym.Env,
    ) -> tuple[TShape, tuple[int, ...], float]:
        """Extract ``(state_shape, action_shape, max_action)`` from ``env``.

        :param env: environment whose action space must be a ``gym.spaces.Box``.
        :return: the observation shape (or the ``n`` attribute of the observation
            space if its shape is empty), the action shape, and the first entry
            of the action space's ``high`` bound.
        :raises ValueError: if the action space is not a ``gym.spaces.Box`` or if
            the observation space provides neither a non-empty shape nor ``n``.
        """
        if not isinstance(env.action_space, gym.spaces.Box):
            raise ValueError(
                "Only environments with continuous action space are supported here. "
                f"But got env with action space: {env.action_space.__class__}.",
            )
        observation_space = env.observation_space
        # Not every space has an `n` attribute; look it up defensively so that a
        # space with an empty/undefined shape reaches the ValueError below instead
        # of failing with an AttributeError.
        state_shape = observation_space.shape or getattr(observation_space, "n", None)
        if not state_shape:
            raise ValueError("Observation space shape is not defined")
        action_shape = env.action_space.shape
        # Only the first entry of the upper bound is used as the maximum action.
        max_action = env.action_space.high[0]
        return state_shape, action_shape, max_action

    def get_action_shape(self) -> TShape:
        """Return the action shape extracted at construction time."""
        return self.action_shape

    def get_observation_shape(self) -> TShape:
        """Return the observation shape extracted at construction time."""
        return self.state_shape

    def get_type(self) -> EnvType:
        """Return ``EnvType.CONTINUOUS``."""
        return EnvType.CONTINUOUS
|
|
|
|
|
2023-09-19 18:53:11 +02:00
|
|
|
|
2023-09-28 20:07:52 +02:00
|
|
|
class DiscreteEnvironments(Environments):
    """Environments reporting ``EnvType.DISCRETE``.

    The observation and action shapes are taken from the respective space's
    ``shape`` or, if that is empty, from its ``n`` attribute.
    """

    def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
        """Store the environments and derive both shapes from ``env``.

        NOTE(review): the signature allows ``env=None``, but ``env`` is
        dereferenced below, so ``None`` is not actually usable here.
        """
        super().__init__(env, train_envs, test_envs)
        obs_space = env.observation_space
        self.observation_shape = obs_space.shape or obs_space.n
        act_space = env.action_space
        self.action_shape = act_space.shape or act_space.n

    def get_action_shape(self) -> TShape:
        """Return the action shape derived at construction time."""
        return self.action_shape

    def get_observation_shape(self) -> TShape:
        """Return the observation shape derived at construction time."""
        return self.observation_shape

    def get_type(self) -> EnvType:
        """Return ``EnvType.DISCRETE``."""
        return EnvType.DISCRETE
|
|
|
|
|
|
|
|
|
2023-09-19 18:53:11 +02:00
|
|
|
class EnvFactory(ABC):
    """Abstract factory producing :class:`Environments` instances.

    Instances are callable; calling one delegates to :meth:`create_envs`.
    """

    @abstractmethod
    def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments:
        """Create the environments.

        :param config: optional configuration object passed on by the caller.
        :return: the created environments.
        """

    def __call__(self, config: PersistableConfigProtocol | None = None) -> Environments:
        """Create the environments by forwarding ``config`` to :meth:`create_envs`."""
        created = self.create_envs(config=config)
        return created
|