from abc import ABC, abstractmethod
from typing import Tuple, Optional, Dict, Any, Union, Sequence

import gymnasium as gym

from tianshou.env import BaseVectorEnv

# A shape is either a single integer or a sequence of integers.
TShape = Union[int, Sequence[int]]


class Environments(ABC):
    """Abstract container for a reference environment plus train/test vector envs.

    Subclasses must implement :meth:`get_action_shape` and
    :meth:`get_state_shape`.
    """

    def __init__(
        self,
        env: Optional[gym.Env],
        train_envs: BaseVectorEnv,
        test_envs: BaseVectorEnv,
    ):
        """Store the given environments on the instance.

        :param env: single environment used to query spaces; may be ``None``,
            in which case :meth:`get_action_space` cannot be used.
        :param train_envs: vectorized environments stored as ``train_envs``.
        :param test_envs: vectorized environments stored as ``test_envs``.
        """
        self.env = env
        self.train_envs = train_envs
        self.test_envs = test_envs

    def info(self) -> Dict[str, Any]:
        """Return a dict with the keys ``"action_shape"`` and ``"state_shape"``."""
        return {
            "action_shape": self.get_action_shape(),
            "state_shape": self.get_state_shape(),
        }

    @abstractmethod
    def get_action_shape(self) -> TShape:
        """Return the shape of the action space."""

    @abstractmethod
    def get_state_shape(self) -> TShape:
        """Return the shape of the state (observation) space."""

    def get_action_space(self) -> gym.Space:
        """Return ``self.env.action_space``.

        :raises AttributeError: if ``self.env`` is ``None``.
        """
        return self.env.action_space


class ContinuousEnvironments(Environments):
    """Environments whose reference env has a ``gym.spaces.Box`` action space.

    On construction the state shape, action shape and maximum action are read
    from ``env`` and cached as ``state_shape``, ``action_shape`` and
    ``max_action``.
    """

    def __init__(
        self,
        env: Optional[gym.Env],
        train_envs: BaseVectorEnv,
        test_envs: BaseVectorEnv,
    ):
        """Store the environments and extract the continuous-env properties.

        :param env: environment whose spaces are inspected. Although annotated
            as optional (matching the base class), it is dereferenced here.
        :param train_envs: vectorized environments stored as ``train_envs``.
        :param test_envs: vectorized environments stored as ``test_envs``.
        :raises ValueError: if the action space is not a ``gym.spaces.Box`` or
            no state shape can be determined.
        """
        super().__init__(env, train_envs, test_envs)
        (
            self.state_shape,
            self.action_shape,
            self.max_action,
        ) = self._get_continuous_env_info(env)

    def info(self) -> Dict[str, Any]:
        """Return the base info dict extended with a ``"max_action"`` entry."""
        d = super().info()
        d["max_action"] = self.max_action
        return d

    @staticmethod
    def _get_continuous_env_info(
        env: gym.Env,
    ) -> Tuple[TShape, Tuple[int, ...], float]:
        """Extract ``(state_shape, action_shape, max_action)`` from ``env``.

        :param env: environment to inspect.
        :return: the observation space's ``shape`` (or its ``n`` attribute when
            ``shape`` is empty or ``None``), the action space's ``shape``, and
            the first entry of the action space's ``high`` bound.
        :raises ValueError: if the action space is not a ``gym.spaces.Box``, or
            if the observation space provides neither a non-empty ``shape`` nor
            a truthy ``n``.
        """
        if not isinstance(env.action_space, gym.spaces.Box):
            raise ValueError(
                "Only environments with continuous action space are supported here. "
                f"But got env with action space: {env.action_space.__class__}."
            )
        obs_space = env.observation_space
        # Use getattr for the fallback: a space with no usable shape and no
        # `n` attribute must reach the ValueError below instead of failing
        # with an AttributeError here.
        state_shape = obs_space.shape or getattr(obs_space, "n", None)
        if not state_shape:
            raise ValueError("Observation space shape is not defined")
        action_shape = env.action_space.shape
        # NOTE(review): only the first entry of `high` is used; this presumably
        # assumes every action dimension shares the same upper bound - confirm.
        max_action = env.action_space.high[0]
        return state_shape, action_shape, max_action

    def get_action_shape(self) -> TShape:
        """Return the action shape cached at construction."""
        return self.action_shape

    def get_state_shape(self) -> TShape:
        """Return the state shape cached at construction."""
        return self.state_shape


class EnvFactory(ABC):
    """Abstract factory that produces an :class:`Environments` instance."""

    @abstractmethod
    def create_envs(self) -> Environments:
        """Create and return the environments."""