92 lines
2.6 KiB
Python
Raw Normal View History

from abc import ABC, abstractmethod
2023-09-20 09:29:34 +02:00
from collections.abc import Sequence
from enum import Enum
2023-09-20 09:29:34 +02:00
from typing import Any
import gymnasium as gym
from tianshou.env import BaseVectorEnv
2023-09-20 09:29:34 +02:00
# Shape descriptor returned by the shape accessors in this module:
# either a single int or a sequence of ints.
TShape = int | Sequence[int]
class EnvType(Enum):
    """Category of an environment: continuous or discrete."""

    CONTINUOUS = "continuous"
    DISCRETE = "discrete"

    def is_discrete(self) -> bool:
        """Return True if this member is ``EnvType.DISCRETE``."""
        # Enum members are singletons, so identity is equivalent to equality here.
        return self is EnvType.DISCRETE

    def is_continuous(self) -> bool:
        """Return True if this member is ``EnvType.CONTINUOUS``."""
        return self is EnvType.CONTINUOUS
class Environments(ABC):
    """Holds a single reference environment together with train and test vector environments.

    Concrete subclasses provide the action/observation shapes and the
    environment type.
    """

    def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
        """Store the given environments on the instance without further processing.

        :param env: single environment used by the space accessors; may be ``None``.
        :param train_envs: vectorized environments stored as ``train_envs``.
        :param test_envs: vectorized environments stored as ``test_envs``.
        """
        self.env, self.train_envs, self.test_envs = env, train_envs, test_envs

    def info(self) -> dict[str, Any]:
        """Return a dict with the keys ``"action_shape"`` and ``"state_shape"``.

        The values come from :meth:`get_action_shape` and
        :meth:`get_observation_shape` respectively.
        """
        summary: dict[str, Any] = {}
        summary["action_shape"] = self.get_action_shape()
        summary["state_shape"] = self.get_observation_shape()
        return summary

    @abstractmethod
    def get_action_shape(self) -> TShape:
        """Return the shape of the action space."""

    @abstractmethod
    def get_observation_shape(self) -> TShape:
        """Return the shape of the observation space."""

    def get_action_space(self) -> gym.Space:
        """Return ``action_space`` of the stored ``env``.

        NOTE: ``env`` may be ``None`` according to the constructor signature;
        the attribute access fails in that case.
        """
        reference_env = self.env
        return reference_env.action_space

    def get_observation_space(self) -> gym.Space:
        """Return ``observation_space`` of the stored ``env``.

        NOTE: ``env`` may be ``None`` according to the constructor signature;
        the attribute access fails in that case.
        """
        reference_env = self.env
        return reference_env.observation_space

    @abstractmethod
    def get_type(self) -> EnvType:
        """Return the :class:`EnvType` of these environments."""
class ContinuousEnvironments(Environments):
    """Environments whose reference ``env`` has a ``gym.spaces.Box`` action space.

    On construction the observation shape, the action shape and the maximum
    action value are read from ``env`` and stored as ``state_shape``,
    ``action_shape`` and ``max_action``.
    """

    def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
        """Store the environments and extract shape/bound information from ``env``.

        :param env: reference environment; although the signature permits
            ``None``, a concrete environment is required here because its
            spaces are inspected.
        :param train_envs: vectorized training environments.
        :param test_envs: vectorized test environments.
        :raises ValueError: if ``env`` is ``None``, its action space is not a
            ``gym.spaces.Box``, or no observation shape can be determined.
        """
        super().__init__(env, train_envs, test_envs)
        self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env)

    def info(self) -> dict[str, Any]:
        """Return the base info dict extended with the key ``"max_action"``."""
        d = super().info()
        d["max_action"] = self.max_action
        return d

    @staticmethod
    def _get_continuous_env_info(
        env: gym.Env | None,
    ) -> tuple[TShape, tuple[int, ...], float]:
        """Extract ``(state_shape, action_shape, max_action)`` from ``env``.

        :param env: environment whose observation and action spaces are inspected.
        :return: the observation shape (the space's ``shape``, or its ``n``
            attribute when ``shape`` is empty), the action space's ``shape``,
            and the first entry of the action space's ``high`` bound.
        :raises ValueError: if ``env`` is ``None``, the action space is not a
            ``gym.spaces.Box``, or the observation space provides neither a
            non-empty ``shape`` nor an ``n`` attribute.
        """
        if env is None:
            # Fail with an explicit message instead of an AttributeError on None.
            raise ValueError(
                "An env is required to determine the properties of a continuous environment, "
                "but got None.",
            )
        if not isinstance(env.action_space, gym.spaces.Box):
            raise ValueError(
                "Only environments with continuous action space are supported here. "
                f"But got env with action space: {env.action_space.__class__}.",
            )
        obs_space = env.observation_space
        # Fall back to ``n`` when ``shape`` is empty/None. ``getattr`` keeps spaces
        # without an ``n`` attribute on the ValueError path below instead of
        # raising an unrelated AttributeError.
        state_shape = obs_space.shape or getattr(obs_space, "n", None)
        if not state_shape:
            raise ValueError("Observation space shape is not defined")
        action_shape = env.action_space.shape
        # NOTE(review): only the first entry of the upper bound is used; this
        # presumably assumes the same bound for every action dimension - confirm.
        max_action = env.action_space.high[0]
        return state_shape, action_shape, max_action

    def get_action_shape(self) -> TShape:
        """Return the action shape extracted at construction time."""
        return self.action_shape

    def get_observation_shape(self) -> TShape:
        """Return the observation shape extracted at construction time."""
        return self.state_shape

    def get_type(self) -> EnvType:
        """Return ``EnvType.CONTINUOUS``."""
        return EnvType.CONTINUOUS
class EnvFactory(ABC):
    """Abstract factory for :class:`Environments` instances."""

    @abstractmethod
    def create_envs(self) -> Environments:
        """Create and return the environments; to be implemented by subclasses."""