Improve environment factory abstractions in high-level API:
 * EnvFactory now treats the creation of a single environment as
   the basic operation on which the higher-level functions build
 * Introduce enum EnvMode to indicate the purpose for which an env
   is created, allowing the factory creation process to change its
   behaviour accordingly
 * Add EnvFactoryGymnasium to provide direct support for envs that
   can be created via gymnasium.make
     - EnvPool is supported via an injectable EnvPoolFactory
     - Existing EnvFactory implementations are now derived from
       EnvFactoryGymnasium
 * Use a separate environment (which uses new EnvMode.WATCH) for
   watching agent performance after training (instead of using test
   environments, which the user may want to configure differently)
			
			
This commit is contained in:
		
							parent
							
								
									8188a904af
								
							
						
					
					
						commit
						eaab7b0a4b
					
				| @ -7,9 +7,15 @@ from collections import deque | ||||
| import cv2 | ||||
| import gymnasium as gym | ||||
| import numpy as np | ||||
| from gymnasium import Env | ||||
| 
 | ||||
| from tianshou.env import ShmemVectorEnv | ||||
| from tianshou.highlevel.env import DiscreteEnvironments, EnvFactory | ||||
| from tianshou.highlevel.env import ( | ||||
|     EnvFactoryGymnasium, | ||||
|     EnvMode, | ||||
|     EnvPoolFactory, | ||||
|     VectorEnvType, | ||||
| ) | ||||
| from tianshou.highlevel.trainer import TrainerStopCallback, TrainingContext | ||||
| 
 | ||||
| try: | ||||
| @ -282,7 +288,7 @@ class FrameStack(gym.Wrapper): | ||||
| 
 | ||||
| 
 | ||||
| def wrap_deepmind( | ||||
|     env_id, | ||||
|     env: Env, | ||||
|     episode_life=True, | ||||
|     clip_rewards=True, | ||||
|     frame_stack=4, | ||||
| @ -293,7 +299,7 @@ def wrap_deepmind( | ||||
| 
 | ||||
|     The observation is channel-first: (c, h, w) instead of (h, w, c). | ||||
| 
 | ||||
|     :param str env_id: the atari environment id. | ||||
|     :param env: the Atari environment to wrap. | ||||
|     :param bool episode_life: wrap the episode life wrapper. | ||||
|     :param bool clip_rewards: wrap the reward clipping wrapper. | ||||
|     :param int frame_stack: wrap the frame stacking wrapper. | ||||
| @ -301,8 +307,6 @@ def wrap_deepmind( | ||||
|     :param bool warp_frame: wrap the grayscale + resize observation wrapper. | ||||
|     :return: the wrapped atari environment. | ||||
|     """ | ||||
|     assert "NoFrameskip" in env_id | ||||
|     env = gym.make(env_id) | ||||
|     env = NoopResetEnv(env, noop_max=30) | ||||
|     env = MaxAndSkipEnv(env, skip=4) | ||||
|     if episode_life: | ||||
| @ -351,19 +355,30 @@ def make_atari_env(task, seed, training_num, test_num, **kwargs): | ||||
|             stack_num=kwargs.get("frame_stack", 4), | ||||
|         ) | ||||
|     else: | ||||
|         assert "NoFrameskip" in task | ||||
|         warnings.warn( | ||||
|             "Recommend using envpool (pip install envpool) to run Atari games more efficiently.", | ||||
|         ) | ||||
|         env = wrap_deepmind(task, **kwargs) | ||||
|         train_envs = ShmemVectorEnv( | ||||
|             [ | ||||
|                 lambda: wrap_deepmind(task, episode_life=True, clip_rewards=True, **kwargs) | ||||
|                 lambda: wrap_deepmind( | ||||
|                     gym.make(task), | ||||
|                     episode_life=True, | ||||
|                     clip_rewards=True, | ||||
|                     **kwargs, | ||||
|                 ) | ||||
|                 for _ in range(training_num) | ||||
|             ], | ||||
|         ) | ||||
|         test_envs = ShmemVectorEnv( | ||||
|             [ | ||||
|                 lambda: wrap_deepmind(task, episode_life=False, clip_rewards=False, **kwargs) | ||||
|                 lambda: wrap_deepmind( | ||||
|                     gym.make(task), | ||||
|                     episode_life=False, | ||||
|                     clip_rewards=False, | ||||
|                     **kwargs, | ||||
|                 ) | ||||
|                 for _ in range(test_num) | ||||
|             ], | ||||
|         ) | ||||
| @ -373,23 +388,49 @@ def make_atari_env(task, seed, training_num, test_num, **kwargs): | ||||
|     return env, train_envs, test_envs | ||||
| 
 | ||||
| 
 | ||||
| class AtariEnvFactory(EnvFactory): | ||||
|     def __init__(self, task: str, seed: int, frame_stack: int, scale: int = 0): | ||||
|         self.task = task | ||||
|         self.seed = seed | ||||
| class AtariEnvFactory(EnvFactoryGymnasium): | ||||
|     def __init__(self, task: str, seed: int, frame_stack: int, scale: bool = False): | ||||
|         assert "NoFrameskip" in task | ||||
|         self.frame_stack = frame_stack | ||||
|         self.scale = scale | ||||
| 
 | ||||
|     def create_envs(self, num_training_envs: int, num_test_envs: int) -> DiscreteEnvironments: | ||||
|         env, train_envs, test_envs = make_atari_env( | ||||
|             task=self.task, | ||||
|             seed=self.seed, | ||||
|             training_num=num_training_envs, | ||||
|             test_num=num_test_envs, | ||||
|             scale=self.scale, | ||||
|             frame_stack=self.frame_stack, | ||||
|         super().__init__( | ||||
|             task=task, | ||||
|             seed=seed, | ||||
|             venv_type=VectorEnvType.SUBPROC_SHARED_MEM, | ||||
|             envpool_factory=self.EnvPoolFactory(self), | ||||
|         ) | ||||
|         return DiscreteEnvironments(env=env, train_envs=train_envs, test_envs=test_envs) | ||||
| 
 | ||||
|     def create_env(self, mode: EnvMode) -> Env: | ||||
|         env = super().create_env(mode) | ||||
|         is_train = mode == EnvMode.TRAIN | ||||
|         return wrap_deepmind( | ||||
|             env, | ||||
|             episode_life=is_train, | ||||
|             clip_rewards=is_train, | ||||
|             frame_stack=self.frame_stack, | ||||
|             scale=self.scale, | ||||
|         ) | ||||
| 
 | ||||
|     class EnvPoolFactory(EnvPoolFactory): | ||||
|         def __init__(self, parent: "AtariEnvFactory"): | ||||
|             self.parent = parent | ||||
|             if self.parent.scale: | ||||
|                 warnings.warn( | ||||
|                     "EnvPool does not include ScaledFloatFrame wrapper, " | ||||
|                     "please compensate by scaling inside your network's forward function (e.g. `x = x / 255.0` for Atari)", | ||||
|                 ) | ||||
| 
 | ||||
|         def _transform_task(self, task: str) -> str: | ||||
|             task = super()._transform_task(task) | ||||
|             return task.replace("NoFrameskip-v4", "-v5") | ||||
| 
 | ||||
|         def _transform_kwargs(self, kwargs: dict, mode: EnvMode) -> dict: | ||||
|             kwargs = super()._transform_kwargs(kwargs, mode) | ||||
|             is_train = mode == EnvMode.TRAIN | ||||
|             kwargs["reward_clip"] = is_train | ||||
|             kwargs["episodic_life"] = is_train | ||||
|             kwargs["stack_num"] = self.parent.frame_stack | ||||
|             return kwargs | ||||
| 
 | ||||
| 
 | ||||
| class AtariStopCallback(TrainerStopCallback): | ||||
|  | ||||
| @ -1,11 +1,13 @@ | ||||
| import logging | ||||
| import pickle | ||||
| import warnings | ||||
| 
 | ||||
| import gymnasium as gym | ||||
| 
 | ||||
| from tianshou.env import ShmemVectorEnv, VectorEnvNormObs | ||||
| from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory | ||||
| from tianshou.env import VectorEnvNormObs | ||||
| from tianshou.highlevel.env import ( | ||||
|     ContinuousEnvironments, | ||||
|     EnvFactoryGymnasium, | ||||
|     EnvPoolFactory, | ||||
|     VectorEnvType, | ||||
| ) | ||||
| from tianshou.highlevel.persistence import Persistence, PersistEvent, RestoreEvent | ||||
| from tianshou.highlevel.world import World | ||||
| 
 | ||||
| @ -24,25 +26,11 @@ def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: in | ||||
| 
 | ||||
|     :return: a tuple of (single env, training envs, test envs). | ||||
|     """ | ||||
|     if envpool is not None: | ||||
|         train_envs = env = envpool.make_gymnasium(task, num_envs=num_train_envs, seed=seed) | ||||
|         test_envs = envpool.make_gymnasium(task, num_envs=num_test_envs, seed=seed) | ||||
|     else: | ||||
|         warnings.warn( | ||||
|             "Recommend using envpool (pip install envpool) " | ||||
|             "to run Mujoco environments more efficiently.", | ||||
|     envs = MujocoEnvFactory(task, seed, obs_norm=obs_norm).create_envs( | ||||
|         num_train_envs, | ||||
|         num_test_envs, | ||||
|     ) | ||||
|         env = gym.make(task) | ||||
|         train_envs = ShmemVectorEnv([lambda: gym.make(task) for _ in range(num_train_envs)]) | ||||
|         test_envs = ShmemVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)]) | ||||
|         train_envs.seed(seed) | ||||
|         test_envs.seed(seed) | ||||
|     if obs_norm: | ||||
|         # obs norm wrapper | ||||
|         train_envs = VectorEnvNormObs(train_envs) | ||||
|         test_envs = VectorEnvNormObs(test_envs, update_obs_rms=False) | ||||
|         test_envs.set_obs_rms(train_envs.get_obs_rms()) | ||||
|     return env, train_envs, test_envs | ||||
|     return envs.env, envs.train_envs, envs.test_envs | ||||
| 
 | ||||
| 
 | ||||
| class MujocoEnvObsRmsPersistence(Persistence): | ||||
| @ -68,21 +56,25 @@ class MujocoEnvObsRmsPersistence(Persistence): | ||||
|         world.envs.test_envs.set_obs_rms(obs_rms) | ||||
| 
 | ||||
| 
 | ||||
| class MujocoEnvFactory(EnvFactory): | ||||
| class MujocoEnvFactory(EnvFactoryGymnasium): | ||||
|     def __init__(self, task: str, seed: int, obs_norm=True): | ||||
|         self.task = task | ||||
|         self.seed = seed | ||||
|         super().__init__( | ||||
|             task=task, | ||||
|             seed=seed, | ||||
|             venv_type=VectorEnvType.SUBPROC_SHARED_MEM, | ||||
|             envpool_factory=EnvPoolFactory(), | ||||
|         ) | ||||
|         self.obs_norm = obs_norm | ||||
| 
 | ||||
|     def create_envs(self, num_training_envs: int, num_test_envs: int) -> ContinuousEnvironments: | ||||
|         env, train_envs, test_envs = make_mujoco_env( | ||||
|             task=self.task, | ||||
|             seed=self.seed, | ||||
|             num_train_envs=num_training_envs, | ||||
|             num_test_envs=num_test_envs, | ||||
|             obs_norm=self.obs_norm, | ||||
|         ) | ||||
|         envs = ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs) | ||||
|         envs = super().create_envs(num_training_envs, num_test_envs) | ||||
|         assert isinstance(envs, ContinuousEnvironments) | ||||
| 
 | ||||
|         # obs norm wrapper | ||||
|         if self.obs_norm: | ||||
|             envs.train_envs = VectorEnvNormObs(envs.train_envs) | ||||
|             envs.test_envs = VectorEnvNormObs(envs.test_envs, update_obs_rms=False) | ||||
|             envs.test_envs.set_obs_rms(envs.train_envs.get_obs_rms()) | ||||
|             envs.set_persistence(MujocoEnvObsRmsPersistence()) | ||||
| 
 | ||||
|         return envs | ||||
|  | ||||
| @ -1,27 +1,14 @@ | ||||
| import gymnasium as gym | ||||
| 
 | ||||
| from tianshou.env import DummyVectorEnv | ||||
| from tianshou.highlevel.env import ( | ||||
|     ContinuousEnvironments, | ||||
|     DiscreteEnvironments, | ||||
|     EnvFactory, | ||||
|     Environments, | ||||
|     EnvFactoryGymnasium, | ||||
|     VectorEnvType, | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| class DiscreteTestEnvFactory(EnvFactory): | ||||
|     def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments: | ||||
|         task = "CartPole-v0" | ||||
|         env = gym.make(task) | ||||
|         train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)]) | ||||
|         test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)]) | ||||
|         return DiscreteEnvironments(env, train_envs, test_envs) | ||||
| class DiscreteTestEnvFactory(EnvFactoryGymnasium): | ||||
|     def __init__(self): | ||||
|         super().__init__(task="CartPole-v0", seed=42, venv_type=VectorEnvType.DUMMY) | ||||
| 
 | ||||
| 
 | ||||
| class ContinuousTestEnvFactory(EnvFactory): | ||||
|     def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments: | ||||
|         task = "Pendulum-v1" | ||||
|         env = gym.make(task) | ||||
|         train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)]) | ||||
|         test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)]) | ||||
|         return ContinuousEnvironments(env, train_envs, test_envs) | ||||
| class ContinuousTestEnvFactory(EnvFactoryGymnasium): | ||||
|     def __init__(self): | ||||
|         super().__init__(task="Pendulum-v1", seed=42, venv_type=VectorEnvType.DUMMY) | ||||
|  | ||||
| @ -1,9 +1,12 @@ | ||||
| import logging | ||||
| from abc import ABC, abstractmethod | ||||
| from collections.abc import Callable, Sequence | ||||
| from enum import Enum | ||||
| from typing import Any, TypeAlias, cast | ||||
| 
 | ||||
| import gymnasium as gym | ||||
| import gymnasium.spaces | ||||
| from gymnasium import Env | ||||
| 
 | ||||
| from tianshou.env import ( | ||||
|     BaseVectorEnv, | ||||
| @ -18,6 +21,8 @@ from tianshou.utils.string import ToStringMixin | ||||
| 
 | ||||
| TObservationShape: TypeAlias = int | Sequence[int] | ||||
| 
 | ||||
| log = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| class EnvType(Enum): | ||||
|     """Enumeration of environment types.""" | ||||
| @ -39,6 +44,23 @@ class EnvType(Enum): | ||||
|         if not self.is_discrete(): | ||||
|             raise AssertionError(f"{requiring_entity} requires discrete environments") | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def from_env(env: Env) -> "EnvType": | ||||
|         if isinstance(env.action_space, gymnasium.spaces.Discrete): | ||||
|             return EnvType.DISCRETE | ||||
|         elif isinstance(env.action_space, gymnasium.spaces.Box): | ||||
|             return EnvType.CONTINUOUS | ||||
|         else: | ||||
|             raise Exception(f"Unsupported environment type with action space {env.action_space}") | ||||
| 
 | ||||
| 
 | ||||
| class EnvMode(Enum): | ||||
|     """Indicates the purpose for which an environment is created.""" | ||||
| 
 | ||||
|     TRAIN = "train" | ||||
|     TEST = "test" | ||||
|     WATCH = "watch" | ||||
| 
 | ||||
| 
 | ||||
| class VectorEnvType(Enum): | ||||
|     DUMMY = "dummy" | ||||
| @ -65,7 +87,7 @@ class VectorEnvType(Enum): | ||||
| 
 | ||||
| 
 | ||||
| class Environments(ToStringMixin, ABC): | ||||
|     """Represents (vectorized) environments.""" | ||||
|     """Represents (vectorized) environments for a learning process.""" | ||||
| 
 | ||||
|     def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv): | ||||
|         self.env = env | ||||
| @ -75,12 +97,11 @@ class Environments(ToStringMixin, ABC): | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def from_factory_and_type( | ||||
|         factory_fn: Callable[[], gym.Env], | ||||
|         factory_fn: Callable[[EnvMode], gym.Env], | ||||
|         env_type: EnvType, | ||||
|         venv_type: VectorEnvType, | ||||
|         num_training_envs: int, | ||||
|         num_test_envs: int, | ||||
|         test_factory_fn: Callable[[], gym.Env] | None = None, | ||||
|     ) -> "Environments": | ||||
|         """Creates a suitable subtype instance from a factory function that creates a single instance and the type of environment (continuous/discrete). | ||||
| 
 | ||||
| @ -89,15 +110,11 @@ class Environments(ToStringMixin, ABC): | ||||
|         :param venv_type: the vector environment type to use for parallelization | ||||
|         :param num_training_envs: the number of training environments to create | ||||
|         :param num_test_envs: the number of test environments to create | ||||
|         :param test_factory_fn: the factory to use for the creation of test environment instances; | ||||
|             if None, use `factory_fn` for all environments (train and test) | ||||
|         :return: the instance | ||||
|         """ | ||||
|         if test_factory_fn is None: | ||||
|             test_factory_fn = factory_fn | ||||
|         train_envs = venv_type.create_venv([factory_fn] * num_training_envs) | ||||
|         test_envs = venv_type.create_venv([test_factory_fn] * num_test_envs) | ||||
|         env = factory_fn() | ||||
|         train_envs = venv_type.create_venv([lambda: factory_fn(EnvMode.TRAIN)] * num_training_envs) | ||||
|         test_envs = venv_type.create_venv([lambda: factory_fn(EnvMode.TEST)] * num_test_envs) | ||||
|         env = factory_fn(EnvMode.TRAIN) | ||||
|         match env_type: | ||||
|             case EnvType.CONTINUOUS: | ||||
|                 return ContinuousEnvironments(env, train_envs, test_envs) | ||||
| @ -153,11 +170,10 @@ class ContinuousEnvironments(Environments): | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def from_factory( | ||||
|         factory_fn: Callable[[], gym.Env], | ||||
|         factory_fn: Callable[[EnvMode], gym.Env], | ||||
|         venv_type: VectorEnvType, | ||||
|         num_training_envs: int, | ||||
|         num_test_envs: int, | ||||
|         test_factory_fn: Callable[[], gym.Env] | None = None, | ||||
|     ) -> "ContinuousEnvironments": | ||||
|         """Creates an instance from a factory function that creates a single instance. | ||||
| 
 | ||||
| @ -165,8 +181,6 @@ class ContinuousEnvironments(Environments): | ||||
|         :param venv_type: the vector environment type to use for parallelization | ||||
|         :param num_training_envs: the number of training environments to create | ||||
|         :param num_test_envs: the number of test environments to create | ||||
|         :param test_factory_fn: the factory to use for the creation of test environment instances; | ||||
|             if None, use `factory_fn` for all environments (train and test) | ||||
|         :return: the instance | ||||
|         """ | ||||
|         return cast( | ||||
| @ -177,7 +191,6 @@ class ContinuousEnvironments(Environments): | ||||
|                 venv_type, | ||||
|                 num_training_envs, | ||||
|                 num_test_envs, | ||||
|                 test_factory_fn=test_factory_fn, | ||||
|             ), | ||||
|         ) | ||||
| 
 | ||||
| @ -222,11 +235,10 @@ class DiscreteEnvironments(Environments): | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def from_factory( | ||||
|         factory_fn: Callable[[], gym.Env], | ||||
|         factory_fn: Callable[[EnvMode], gym.Env], | ||||
|         venv_type: VectorEnvType, | ||||
|         num_training_envs: int, | ||||
|         num_test_envs: int, | ||||
|         test_factory_fn: Callable[[], gym.Env] | None = None, | ||||
|     ) -> "DiscreteEnvironments": | ||||
|         """Creates an instance from a factory function that creates a single instance. | ||||
| 
 | ||||
| @ -234,8 +246,6 @@ class DiscreteEnvironments(Environments): | ||||
|         :param venv_type: the vector environment type to use for parallelization | ||||
|         :param num_training_envs: the number of training environments to create | ||||
|         :param num_test_envs: the number of test environments to create | ||||
|         :param test_factory_fn: the factory to use for the creation of test environment instances; | ||||
|             if None, use `factory_fn` for all environments (train and test) | ||||
|         :return: the instance | ||||
|         """ | ||||
|         return cast( | ||||
| @ -246,7 +256,6 @@ class DiscreteEnvironments(Environments): | ||||
|                 venv_type, | ||||
|                 num_training_envs, | ||||
|                 num_test_envs, | ||||
|                 test_factory_fn=test_factory_fn, | ||||
|             ), | ||||
|         ) | ||||
| 
 | ||||
| @ -260,7 +269,156 @@ class DiscreteEnvironments(Environments): | ||||
|         return EnvType.DISCRETE | ||||
| 
 | ||||
| 
 | ||||
| class EnvPoolFactory: | ||||
|     """A factory for the creation of envpool-based vectorized environments.""" | ||||
| 
 | ||||
|     def _transform_task(self, task: str) -> str: | ||||
|         return task | ||||
| 
 | ||||
|     def _transform_kwargs(self, kwargs: dict, mode: EnvMode) -> dict: | ||||
|         """Transforms gymnasium keyword arguments to be envpool-compatible. | ||||
| 
 | ||||
|         :param kwargs: keyword arguments that would normally be passed to `gymnasium.make`. | ||||
|         :param mode: the environment mode | ||||
|         :return: the transformed keyword arguments | ||||
|         """ | ||||
|         kwargs = dict(kwargs) | ||||
|         if "render_mode" in kwargs: | ||||
|             del kwargs["render_mode"] | ||||
|         return kwargs | ||||
| 
 | ||||
|     def create_venv( | ||||
|         self, | ||||
|         task: str, | ||||
|         num_envs: int, | ||||
|         mode: EnvMode, | ||||
|         seed: int, | ||||
|         kwargs: dict, | ||||
|     ) -> BaseVectorEnv | None: | ||||
|         try: | ||||
|             import envpool | ||||
| 
 | ||||
|             envpool_task = self._transform_task(task) | ||||
|             envpool_kwargs = self._transform_kwargs(kwargs, mode) | ||||
|             return envpool.make_gymnasium( | ||||
|                 envpool_task, | ||||
|                 num_envs=num_envs, | ||||
|                 seed=seed, | ||||
|                 **envpool_kwargs, | ||||
|             ) | ||||
|         except ImportError: | ||||
|             return None | ||||
| 
 | ||||
| 
 | ||||
| class EnvFactory(ToStringMixin, ABC): | ||||
|     def __init__(self, venv_type: VectorEnvType): | ||||
|         """:param venv_type: the type of vectorized environment to use""" | ||||
|         self.venv_type = venv_type | ||||
| 
 | ||||
|     @abstractmethod | ||||
|     def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments: | ||||
|     def create_env(self, mode: EnvMode) -> Env: | ||||
|         pass | ||||
| 
 | ||||
|     def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv: | ||||
|         """Create vectorized environments. | ||||
| 
 | ||||
|         :param num_envs: the number of environments | ||||
|         :param mode: the mode for which to create | ||||
|         :return: the vectorized environments | ||||
|         """ | ||||
|         return self.venv_type.create_venv([lambda: self.create_env(mode)] * num_envs) | ||||
| 
 | ||||
|     def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments: | ||||
|         """Create environments for learning. | ||||
| 
 | ||||
|         :param num_training_envs: the number of training environments | ||||
|         :param num_test_envs: the number of test environments | ||||
|         :return: the environments | ||||
|         """ | ||||
|         env = self.create_env(EnvMode.TRAIN) | ||||
|         train_envs = self.create_venv(num_training_envs, EnvMode.TRAIN) | ||||
|         test_envs = self.create_venv(num_test_envs, EnvMode.TEST) | ||||
|         match EnvType.from_env(env): | ||||
|             case EnvType.DISCRETE: | ||||
|                 return DiscreteEnvironments(env, train_envs, test_envs) | ||||
|             case EnvType.CONTINUOUS: | ||||
|                 return ContinuousEnvironments(env, train_envs, test_envs) | ||||
|             case _: | ||||
|                 raise ValueError | ||||
| 
 | ||||
| 
 | ||||
| class EnvFactoryGymnasium(EnvFactory): | ||||
|     """Factory for environments that can be created via `gymnasium.make` (or via `envpool.make_gymnasium`).""" | ||||
| 
 | ||||
|     def __init__( | ||||
|         self, | ||||
|         *, | ||||
|         task: str, | ||||
|         seed: int, | ||||
|         venv_type: VectorEnvType, | ||||
|         envpool_factory: EnvPoolFactory | None = None, | ||||
|         render_mode_train: str | None = None, | ||||
|         render_mode_test: str | None = None, | ||||
|         render_mode_watch: str = "human", | ||||
|         **kwargs: Any, | ||||
|     ): | ||||
|         """:param task: the gymnasium task/environment identifier | ||||
|         :param seed: the random seed | ||||
|         :param venv_type: the type of vectorized environment to use. If `envpool_factory` is specified, this is but a fallback. | ||||
|         :param envpool_factory: the factory to use for envpool-based vectorized environment creation if `envpool` is installed. | ||||
|             If it is not installed, `venv_type` applies as a fallback. | ||||
|         :param render_mode_train: the render mode to use for training environments | ||||
|         :param render_mode_test: the render mode to use for test environments | ||||
|         :param render_mode_watch: the render mode to use for environments that are used to watch agent performance | ||||
|         :param kwargs: additional keyword arguments to pass on to `gymnasium.make`. | ||||
|             If envpool is used, the gymnasium parameters will be appropriately translated for use with | ||||
|             `envpool.make_gymnasium`. | ||||
|         """ | ||||
|         super().__init__(venv_type) | ||||
|         self.task = task | ||||
|         self.envpool_factory = envpool_factory | ||||
|         self.seed = seed | ||||
|         self.render_modes = { | ||||
|             EnvMode.TRAIN: render_mode_train, | ||||
|             EnvMode.TEST: render_mode_test, | ||||
|             EnvMode.WATCH: render_mode_watch, | ||||
|         } | ||||
|         self.kwargs = kwargs | ||||
| 
 | ||||
|     def _create_kwargs(self, mode: EnvMode) -> dict: | ||||
|         """Adapts the keyword arguments for the given mode. | ||||
| 
 | ||||
|         :param mode: the mode | ||||
|         :return: adapted keyword arguments | ||||
|         """ | ||||
|         kwargs = dict(self.kwargs) | ||||
|         kwargs["render_mode"] = self.render_modes.get(mode) | ||||
|         return kwargs | ||||
| 
 | ||||
|     def create_env(self, mode: EnvMode) -> Env: | ||||
|         """Creates a single environment for the given mode. | ||||
| 
 | ||||
|         :param mode: the mode | ||||
|         :return: an environment | ||||
|         """ | ||||
|         kwargs = self._create_kwargs(mode) | ||||
|         return gymnasium.make(self.task, **kwargs) | ||||
| 
 | ||||
|     def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv: | ||||
|         if self.envpool_factory is not None: | ||||
|             venv = self.envpool_factory.create_venv( | ||||
|                 self.task, | ||||
|                 num_envs, | ||||
|                 mode, | ||||
|                 self.seed, | ||||
|                 self._create_kwargs(mode), | ||||
|             ) | ||||
|             if venv is not None: | ||||
|                 return venv | ||||
|             log.debug( | ||||
|                 f"EnvPool-based creation could not be applied, falling back to default based on {self.venv_type}", | ||||
|             ) | ||||
| 
 | ||||
|         venv = super().create_venv(num_envs, mode) | ||||
|         venv.seed(self.seed) | ||||
|         return venv | ||||
|  | ||||
| @ -26,7 +26,7 @@ from tianshou.highlevel.agent import ( | ||||
|     TRPOAgentFactory, | ||||
| ) | ||||
| from tianshou.highlevel.config import SamplingConfig | ||||
| from tianshou.highlevel.env import EnvFactory | ||||
| from tianshou.highlevel.env import EnvFactory, EnvMode | ||||
| from tianshou.highlevel.logger import LoggerFactory, LoggerFactoryDefault, TLogger | ||||
| from tianshou.highlevel.module.actor import ( | ||||
|     ActorFactory, | ||||
| @ -293,7 +293,7 @@ class Experiment(ToStringMixin): | ||||
|                 self._watch_agent( | ||||
|                     self.config.watch_num_episodes, | ||||
|                     policy, | ||||
|                     test_collector, | ||||
|                     self.env_factory, | ||||
|                     self.config.watch_render, | ||||
|                 ) | ||||
| 
 | ||||
| @ -303,15 +303,18 @@ class Experiment(ToStringMixin): | ||||
|     def _watch_agent( | ||||
|         num_episodes: int, | ||||
|         policy: BasePolicy, | ||||
|         test_collector: Collector, | ||||
|         env_factory: EnvFactory, | ||||
|         render: float, | ||||
|     ) -> None: | ||||
|         policy.eval() | ||||
|         test_collector.reset() | ||||
|         result = test_collector.collect(n_episode=num_episodes, render=render) | ||||
|         env = env_factory.create_env(EnvMode.WATCH) | ||||
|         collector = Collector(policy, env) | ||||
|         result = collector.collect(n_episode=num_episodes, render=render) | ||||
|         assert result.returns_stat is not None  # for mypy | ||||
|         assert result.lens_stat is not None  # for mypy | ||||
|         print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") | ||||
|         log.info( | ||||
|             f"Watched episodes: mean reward={result.returns_stat.mean}, mean episode length={result.lens_stat.mean}", | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| class ExperimentBuilder: | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user