2023-09-19 18:53:11 +02:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								from abc import ABC, abstractmethod
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-24 15:38:16 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								from collections.abc import Callable, Sequence
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-21 12:36:27 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								from enum import Enum
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-24 15:38:16 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								from typing import Any, TypeAlias, cast
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-19 18:53:11 +02:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								import gymnasium as gym
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-24 15:38:16 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								from tianshou.env import (
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    BaseVectorEnv,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    DummyVectorEnv,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    RayVectorEnv,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    ShmemVectorEnv,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    SubprocVectorEnv,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								)
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-24 12:07:23 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								from tianshou.highlevel.persistence import Persistence
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-09 17:22:52 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								from tianshou.utils.net.common import TActionShape
							 | 
						
					
						
							
								
									
										
										
										
											2023-11-07 10:54:22 +01:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								from tianshou.utils.string import ToStringMixin
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-19 18:53:11 +02:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-09 17:22:52 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								TObservationShape: TypeAlias = int | Sequence[int]
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-19 18:53:11 +02:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-21 12:36:27 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								class EnvType(Enum):
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-16 18:19:31 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    """Enumeration of environment types."""
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-21 12:36:27 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    CONTINUOUS = "continuous"
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    DISCRETE = "discrete"
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-09 17:22:52 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    def is_discrete(self) -> bool:
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-21 12:36:27 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        return self == EnvType.DISCRETE
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-09 17:22:52 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    def is_continuous(self) -> bool:
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-21 12:36:27 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        return self == EnvType.CONTINUOUS
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-09 17:22:52 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    def assert_continuous(self, requiring_entity: Any) -> None:
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-26 15:35:18 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        if not self.is_continuous():
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            raise AssertionError(f"{requiring_entity} requires continuous environments")
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-09 17:22:52 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    def assert_discrete(self, requiring_entity: Any) -> None:
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-26 15:35:18 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        if not self.is_discrete():
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            raise AssertionError(f"{requiring_entity} requires discrete environments")
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-21 12:36:27 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-24 15:38:16 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								class VectorEnvType(Enum):
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    DUMMY = "dummy"
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    """Vectorized environment without parallelization; environments are processed sequentially"""
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    SUBPROC = "subproc"
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    """Parallelization based on `subprocess`"""
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    SUBPROC_SHARED_MEM = "shmem"
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    """Parallelization based on `subprocess` with shared memory"""
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    RAY = "ray"
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    """Parallelization based on the `ray` library"""
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    def create_venv(self, factories: list[Callable[[], gym.Env]]) -> BaseVectorEnv:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        match self:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            case VectorEnvType.DUMMY:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                return DummyVectorEnv(factories)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            case VectorEnvType.SUBPROC:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                return SubprocVectorEnv(factories)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            case VectorEnvType.SUBPROC_SHARED_MEM:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                return ShmemVectorEnv(factories)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            case VectorEnvType.RAY:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                return RayVectorEnv(factories)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            case _:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                raise NotImplementedError(self)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-10 13:26:07 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								class Environments(ToStringMixin, ABC):
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-16 18:19:31 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    """Represents (vectorized) environments."""
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-09 17:22:52 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-19 18:53:11 +02:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        self.env = env
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        self.train_envs = train_envs
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        self.test_envs = test_envs
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-13 12:25:28 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        self.persistence: Sequence[Persistence] = []
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-19 18:53:11 +02:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-24 15:38:16 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    @staticmethod
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    def from_factory_and_type(
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        factory_fn: Callable[[], gym.Env],
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        env_type: EnvType,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        venv_type: VectorEnvType,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        num_training_envs: int,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        num_test_envs: int,
							 | 
						
					
						
							
								
									
										
										
										
											2023-12-18 12:52:05 +01:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        test_factory_fn: Callable[[], gym.Env] | None = None,
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-24 15:38:16 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    ) -> "Environments":
							 | 
						
					
						
							
								
									
										
										
										
											2023-12-04 13:52:10 +01:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        """Creates a suitable subtype instance from a factory function that creates a single instance and the type of environment (continuous/discrete).
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-24 15:38:16 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        :param factory_fn: the factory for a single environment instance
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        :param env_type: the type of environments created by `factory_fn`
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        :param venv_type: the vector environment type to use for parallelization
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        :param num_training_envs: the number of training environments to create
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        :param num_test_envs: the number of test environments to create
							 | 
						
					
						
							
								
									
										
										
										
											2023-12-18 12:52:05 +01:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        :param test_factory_fn: the factory to use for the creation of test environment instances;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            if None, use `factory_fn` for all environments (train and test)
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-24 15:38:16 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        :return: the instance
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        """
							 | 
						
					
						
							
								
									
										
										
										
											2023-12-18 12:52:05 +01:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        if test_factory_fn is None:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            test_factory_fn = factory_fn
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-24 15:38:16 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        train_envs = venv_type.create_venv([factory_fn] * num_training_envs)
							 | 
						
					
						
							
								
									
										
										
										
											2023-12-18 12:52:05 +01:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        test_envs = venv_type.create_venv([test_factory_fn] * num_test_envs)
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-24 15:38:16 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        env = factory_fn()
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        match env_type:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            case EnvType.CONTINUOUS:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                return ContinuousEnvironments(env, train_envs, test_envs)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            case EnvType.DISCRETE:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                return DiscreteEnvironments(env, train_envs, test_envs)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            case _:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                raise ValueError(f"Environment type {env_type} not handled")
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-10 13:26:07 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    def _tostring_includes(self) -> list[str]:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        return []
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    def _tostring_additional_entries(self) -> dict[str, Any]:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        return self.info()
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-20 09:29:34 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    def info(self) -> dict[str, Any]:
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-26 15:35:18 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        return {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            "action_shape": self.get_action_shape(),
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            "state_shape": self.get_observation_shape(),
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        }
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-19 18:53:11 +02:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-13 12:25:28 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    def set_persistence(self, *p: Persistence) -> None:
							 | 
						
					
						
							
								
									
										
										
										
											2023-12-04 13:52:10 +01:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        """Associates the given persistence handlers which may persist and restore environment-specific information.
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-16 18:19:31 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        :param p: persistence handlers
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        """
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-12 15:01:49 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        self.persistence = p
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-19 18:53:11 +02:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    @abstractmethod
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-09 17:22:52 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    def get_action_shape(self) -> TActionShape:
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-19 18:53:11 +02:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        pass
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    @abstractmethod
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-09 17:22:52 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    def get_observation_shape(self) -> TObservationShape:
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-19 18:53:11 +02:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        pass
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    def get_action_space(self) -> gym.Space:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        return self.env.action_space
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-25 17:56:37 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    def get_observation_space(self) -> gym.Space:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        return self.env.observation_space
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-21 12:36:27 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    @abstractmethod
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    def get_type(self) -> EnvType:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        pass
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-19 18:53:11 +02:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								class ContinuousEnvironments(Environments):
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-16 18:19:31 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    """Represents (vectorized) continuous environments."""
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-09 17:22:52 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-19 18:53:11 +02:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        super().__init__(env, train_envs, test_envs)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-24 15:38:16 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    @staticmethod
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    def from_factory(
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        factory_fn: Callable[[], gym.Env],
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        venv_type: VectorEnvType,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        num_training_envs: int,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        num_test_envs: int,
							 | 
						
					
						
							
								
									
										
										
										
											2023-12-18 12:52:05 +01:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        test_factory_fn: Callable[[], gym.Env] | None = None,
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-24 15:38:16 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    ) -> "ContinuousEnvironments":
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        """Creates an instance from a factory function that creates a single instance.
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        :param factory_fn: the factory for a single environment instance
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        :param venv_type: the vector environment type to use for parallelization
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        :param num_training_envs: the number of training environments to create
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        :param num_test_envs: the number of test environments to create
							 | 
						
					
						
							
								
									
										
										
										
											2023-12-18 12:52:05 +01:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        :param test_factory_fn: the factory to use for the creation of test environment instances;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            if None, use `factory_fn` for all environments (train and test)
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-24 15:38:16 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        :return: the instance
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        """
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        return cast(
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            ContinuousEnvironments,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            Environments.from_factory_and_type(
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                factory_fn,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                EnvType.CONTINUOUS,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                venv_type,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                num_training_envs,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                num_test_envs,
							 | 
						
					
						
							
								
									
										
										
										
											2023-12-18 12:52:05 +01:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								                test_factory_fn=test_factory_fn,
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-24 15:38:16 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								            ),
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        )
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-09 17:22:52 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    def info(self) -> dict[str, Any]:
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-19 18:53:11 +02:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        d = super().info()
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        d["max_action"] = self.max_action
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        return d
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    @staticmethod
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    def _get_continuous_env_info(
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-20 09:29:34 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        env: gym.Env,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    ) -> tuple[tuple[int, ...], tuple[int, ...], float]:
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-19 18:53:11 +02:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        if not isinstance(env.action_space, gym.spaces.Box):
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            raise ValueError(
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                "Only environments with continuous action space are supported here. "
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-20 09:29:34 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								                f"But got env with action space: {env.action_space.__class__}.",
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-19 18:53:11 +02:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            )
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-09 17:22:52 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        state_shape = env.observation_space.shape or env.observation_space.n  # type: ignore
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-19 18:53:11 +02:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        if not state_shape:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            raise ValueError("Observation space shape is not defined")
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        action_shape = env.action_space.shape
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        max_action = env.action_space.high[0]
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        return state_shape, action_shape, max_action
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-09 17:22:52 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    def get_action_shape(self) -> TActionShape:
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-19 18:53:11 +02:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        return self.action_shape
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-09 17:22:52 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    def get_observation_shape(self) -> TObservationShape:
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-19 18:53:11 +02:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        return self.state_shape
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-26 15:35:18 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    def get_type(self) -> EnvType:
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-21 12:36:27 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        return EnvType.CONTINUOUS
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-19 18:53:11 +02:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-28 20:07:52 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								class DiscreteEnvironments(Environments):
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-16 18:19:31 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    """Represents (vectorized) discrete environments."""
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-09 17:22:52 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-28 20:07:52 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        super().__init__(env, train_envs, test_envs)
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-09 17:22:52 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        self.observation_shape = env.observation_space.shape or env.observation_space.n  # type: ignore
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        self.action_shape = env.action_space.shape or env.action_space.n  # type: ignore
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-28 20:07:52 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-24 15:38:16 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    @staticmethod
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    def from_factory(
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        factory_fn: Callable[[], gym.Env],
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        venv_type: VectorEnvType,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        num_training_envs: int,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        num_test_envs: int,
							 | 
						
					
						
							
								
									
										
										
										
											2023-12-18 12:52:05 +01:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        test_factory_fn: Callable[[], gym.Env] | None = None,
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-24 15:38:16 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    ) -> "DiscreteEnvironments":
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        """Creates an instance from a factory function that creates a single instance.
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        :param factory_fn: the factory for a single environment instance
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        :param venv_type: the vector environment type to use for parallelization
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        :param num_training_envs: the number of training environments to create
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        :param num_test_envs: the number of test environments to create
							 | 
						
					
						
							
								
									
										
										
										
											2023-12-18 12:52:05 +01:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        :param test_factory_fn: the factory to use for the creation of test environment instances;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            if None, use `factory_fn` for all environments (train and test)
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-24 15:38:16 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        :return: the instance
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        """
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        return cast(
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            DiscreteEnvironments,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            Environments.from_factory_and_type(
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                factory_fn,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                EnvType.CONTINUOUS,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                venv_type,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                num_training_envs,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                num_test_envs,
							 | 
						
					
						
							
								
									
										
										
										
											2023-12-18 12:52:05 +01:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								                test_factory_fn=test_factory_fn,
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-24 15:38:16 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								            ),
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        )
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-09 17:22:52 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    def get_action_shape(self) -> TActionShape:
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-28 20:07:52 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        return self.action_shape
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-09 17:22:52 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    def get_observation_shape(self) -> TObservationShape:
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-28 20:07:52 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        return self.observation_shape
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    def get_type(self) -> EnvType:
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        return EnvType.DISCRETE
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-17 12:05:36 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								class EnvFactory(ToStringMixin, ABC):
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-19 18:53:11 +02:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    @abstractmethod
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-24 12:07:23 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-20 09:29:34 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        pass
							 |