Policy objects are now parametrised by converting the parameter dataclass instances to kwargs, with injectable conversions applied along the way.
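A minimal sketch of that idea, for orientation only: each algorithm's parameters live in a dataclass, and a conversion step turns the instance into the kwargs expected by the policy constructor, with injected transformations rewriting individual entries. The names below (`PGParams`, `params_to_kwargs`, the transformer hook) are illustrative assumptions, not the library's actual API.

from collections.abc import Callable, Sequence
from dataclasses import asdict, dataclass
from typing import Any


@dataclass
class PGParams:
    # Hypothetical parameter dataclass, for illustration only.
    lr: float = 1e-3
    discount_factor: float = 0.99


def params_to_kwargs(
    params: Any,
    transformers: Sequence[Callable[[dict[str, Any]], dict[str, Any]]] = (),
) -> dict[str, Any]:
    # Convert the dataclass to a plain dict, then let each injected
    # transformer rewrite entries (e.g. replace "lr" with an optimizer factory).
    kwargs = asdict(params)
    for transform in transformers:
        kwargs = transform(kwargs)
    return kwargs

The module below concerns a different part of the high-level API: the environment abstraction and its factory.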
from abc import ABC, abstractmethod
from collections.abc import Sequence
from enum import Enum
from typing import Any

import gymnasium as gym

from tianshou.env import BaseVectorEnv

TShape = int | Sequence[int]


class EnvType(Enum):
    """Enumeration of environment types, distinguished by their action space."""

    CONTINUOUS = "continuous"
    DISCRETE = "discrete"

    def is_discrete(self) -> bool:
        return self == EnvType.DISCRETE

    def is_continuous(self) -> bool:
        return self == EnvType.CONTINUOUS


class Environments(ABC):
    """Represents an environment along with its vectorised train and test counterparts."""

    def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
        self.env = env
        self.train_envs = train_envs
        self.test_envs = test_envs

    def info(self) -> dict[str, Any]:
        return {"action_shape": self.get_action_shape(), "state_shape": self.get_observation_shape()}

    @abstractmethod
    def get_action_shape(self) -> TShape:
        pass

    @abstractmethod
    def get_observation_shape(self) -> TShape:
        pass

    def get_action_space(self) -> gym.Space:
        return self.env.action_space

    def get_observation_space(self) -> gym.Space:
        return self.env.observation_space

    @abstractmethod
    def get_type(self) -> EnvType:
        pass


class ContinuousEnvironments(Environments):
    """Environments with a continuous (Box) action space."""

    def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
        super().__init__(env, train_envs, test_envs)
        self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env)

    def info(self) -> dict[str, Any]:
        d = super().info()
        d["max_action"] = self.max_action
        return d

    @staticmethod
    def _get_continuous_env_info(
        env: gym.Env,
    ) -> tuple[tuple[int, ...], tuple[int, ...], float]:
        if not isinstance(env.action_space, gym.spaces.Box):
            raise ValueError(
                "Only environments with continuous action space are supported here. "
                f"But got env with action space: {env.action_space.__class__}.",
            )
        state_shape = env.observation_space.shape or env.observation_space.n
        if not state_shape:
            raise ValueError("Observation space shape is not defined")
        action_shape = env.action_space.shape
        max_action = env.action_space.high[0]
        return state_shape, action_shape, max_action

    def get_action_shape(self) -> TShape:
        return self.action_shape

    def get_observation_shape(self) -> TShape:
        return self.state_shape

    def get_type(self) -> EnvType:
        return EnvType.CONTINUOUS


class EnvFactory(ABC):
    """Factory for the creation of an environment and its train/test vector environments."""

    @abstractmethod
    def create_envs(self) -> Environments:
        pass
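As a usage sketch (not part of the module above), a concrete EnvFactory might look as follows, assuming gymnasium's "Pendulum-v1" and tianshou's DummyVectorEnv; the class name and environment counts are hypothetical.

from tianshou.env import DummyVectorEnv


class PendulumEnvFactory(EnvFactory):
    # Hypothetical factory for a continuous-control task, for illustration only.
    def __init__(self, num_train_envs: int = 4, num_test_envs: int = 2):
        self.num_train_envs = num_train_envs
        self.num_test_envs = num_test_envs

    def create_envs(self) -> Environments:
        env = gym.make("Pendulum-v1")
        train_envs = DummyVectorEnv(
            [lambda: gym.make("Pendulum-v1") for _ in range(self.num_train_envs)]
        )
        test_envs = DummyVectorEnv(
            [lambda: gym.make("Pendulum-v1") for _ in range(self.num_test_envs)]
        )
        return ContinuousEnvironments(env, train_envs, test_envs)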