71 lines
2.2 KiB
Python
Raw Normal View History

from abc import ABC, abstractmethod
from typing import Tuple, Optional, Dict, Any, Union, Sequence
import gymnasium as gym
from tianshou.env import BaseVectorEnv
# A shape descriptor: either a bare int (e.g. a space's ``n``) or a sequence of dimension sizes.
TShape = Union[Sequence[int], int]
class Environments(ABC):
    """Bundle of an optional single environment plus vectorized train/test environments.

    Subclasses must report the action and state shapes through
    :meth:`get_action_shape` and :meth:`get_state_shape`.
    """

    def __init__(self, env: Optional[gym.Env], train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
        """
        :param env: a single ``gym.Env`` instance, or ``None``; when ``None``,
            :meth:`get_action_space` cannot be used.
        :param train_envs: vectorized environments, stored as ``self.train_envs``.
        :param test_envs: vectorized environments, stored as ``self.test_envs``.
        """
        self.env = env
        self.train_envs = train_envs
        self.test_envs = test_envs

    def info(self) -> Dict[str, Any]:
        """Return a dict with the ``"action_shape"`` and ``"state_shape"`` of these environments."""
        return {
            "action_shape": self.get_action_shape(),
            "state_shape": self.get_state_shape(),
        }

    @abstractmethod
    def get_action_shape(self) -> TShape:
        """Return the shape of the action space."""

    @abstractmethod
    def get_state_shape(self) -> TShape:
        """Return the shape of the observation (state) space."""

    def get_action_space(self) -> gym.Space:
        """Return the action space of the single environment ``self.env``.

        :raises ValueError: if no single environment was provided (``env`` is ``None``).
        """
        if self.env is None:
            # `env` is Optional; without this check callers would get an opaque
            # "'NoneType' object has no attribute 'action_space'" error.
            raise ValueError("Cannot query the action space: no single environment (env=None) was provided.")
        return self.env.action_space
class ContinuousEnvironments(Environments):
    """Environments whose action space is a continuous ``gym.spaces.Box``.

    The state shape, action shape and action bound are extracted once from
    ``env`` at construction time and stored as ``state_shape``,
    ``action_shape`` and ``max_action``.
    """

    def __init__(self, env: Optional[gym.Env], train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
        """
        :param env: the single environment to inspect; must not be ``None``.
        :param train_envs: vectorized environments, stored as ``self.train_envs``.
        :param test_envs: vectorized environments, stored as ``self.test_envs``.
        :raises ValueError: see :meth:`_get_continuous_env_info`.
        """
        super().__init__(env, train_envs, test_envs)
        self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env)

    def info(self) -> Dict[str, Any]:
        """Return the base info dict extended with ``"max_action"``."""
        d = super().info()
        d["max_action"] = self.max_action
        return d

    @staticmethod
    def _get_continuous_env_info(
        env: Optional[gym.Env],
    ) -> Tuple[TShape, Tuple[int, ...], float]:
        """Extract ``(state_shape, action_shape, max_action)`` from ``env``.

        :param env: the single environment to inspect.
        :return: the observation shape (or the observation space's ``n`` when its
            shape is empty), the action shape, and the upper bound of the first
            action dimension as a Python ``float``.
        :raises ValueError: if ``env`` is ``None``, its action space is not a
            ``Box``, or no observation shape can be determined.
        """
        if env is None:
            raise ValueError("A single environment instance is required to infer the environment info, but got None.")
        if not isinstance(env.action_space, gym.spaces.Box):
            raise ValueError(
                "Only environments with continuous action space are supported here. "
                f"But got env with action space: {env.action_space.__class__}."
            )
        observation_space = env.observation_space
        # Fall back to `.n` when the shape is empty/None (e.g. a Discrete space).
        # getattr keeps spaces lacking `.n` on the ValueError path below instead
        # of crashing with an AttributeError.
        state_shape = observation_space.shape or getattr(observation_space, "n", None)
        if not state_shape:
            raise ValueError("Observation space shape is not defined")
        action_shape = env.action_space.shape
        # NOTE(review): only the first dimension's upper bound is used, which
        # assumes all action dimensions share the same bound — confirm per env.
        max_action = float(env.action_space.high[0])
        return state_shape, action_shape, max_action

    def get_action_shape(self) -> TShape:
        """Return the action shape extracted at construction time."""
        return self.action_shape

    def get_state_shape(self) -> TShape:
        """Return the state shape extracted at construction time."""
        return self.state_shape
class EnvFactory(ABC):
    """Abstract factory producing an :class:`Environments` instance."""

    @abstractmethod
    def create_envs(self) -> Environments:
        """Build and return the environments; concrete factories must override this."""