Improve environment factory abstractions in the high-level API:
* EnvFactory now treats the creation of a single environment as the
  basic operation upon which the more high-level functions build
* Introduce enum EnvMode to indicate the purpose for which an env
  is created, allowing the factory creation process to adapt its
  behaviour accordingly
* Add EnvFactoryGymnasium to provide direct support for envs that
  can be created via gymnasium.make
    - EnvPool is supported via an injectable EnvPoolFactory
    - Existing EnvFactory implementations are now derived from
      EnvFactoryGymnasium
* Use a separate environment (created with the new EnvMode.WATCH) for
  watching agent performance after training, instead of reusing the test
  environments, which the user may want to configure differently
  (see the usage sketch below)
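To illustrate how these pieces are meant to fit together, here is a minimal usage sketch. It is not part of the commit: the class MyEnvFactory, the task "Pendulum-v1" and the seed value are illustrative assumptions; EnvFactoryGymnasium, EnvMode, VectorEnvType and the create_env/create_envs methods are those introduced in the diff below.

# Minimal usage sketch (not part of this commit); MyEnvFactory, the task
# "Pendulum-v1" and seed=42 are illustrative assumptions.
from gymnasium import Env

from tianshou.highlevel.env import EnvFactoryGymnasium, EnvMode, VectorEnvType


class MyEnvFactory(EnvFactoryGymnasium):
    def __init__(self) -> None:
        # venv_type is the fallback vectorization; an EnvPoolFactory could be
        # injected via envpool_factory= to use envpool when it is installed.
        super().__init__(task="Pendulum-v1", seed=42, venv_type=VectorEnvType.DUMMY)

    def create_env(self, mode: EnvMode) -> Env:
        # The mode allows per-purpose configuration of the single environment.
        env = super().create_env(mode)  # gymnasium.make(task, render_mode=...)
        if mode == EnvMode.WATCH:
            pass  # e.g. add wrappers only for the env used to watch the agent
        return env


factory = MyEnvFactory()
envs = factory.create_envs(num_training_envs=4, num_test_envs=2)  # TRAIN/TEST envs
watch_env = factory.create_env(EnvMode.WATCH)  # separate env for watching

Because create_env receives the EnvMode, one factory instance can hand out differently configured environments for training, testing and watching; the experiment runner relies on exactly this to create a dedicated EnvMode.WATCH environment instead of reusing the test environments.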
			
			
Parent: 8188a904af
Commit: eaab7b0a4b
@@ -7,9 +7,15 @@ from collections import deque
 import cv2
 import gymnasium as gym
 import numpy as np
+from gymnasium import Env
 
 from tianshou.env import ShmemVectorEnv
-from tianshou.highlevel.env import DiscreteEnvironments, EnvFactory
+from tianshou.highlevel.env import (
+    EnvFactoryGymnasium,
+    EnvMode,
+    EnvPoolFactory,
+    VectorEnvType,
+)
 from tianshou.highlevel.trainer import TrainerStopCallback, TrainingContext
 
 try:
@@ -282,7 +288,7 @@ class FrameStack(gym.Wrapper):
 
 
 def wrap_deepmind(
-    env_id,
+    env: Env,
     episode_life=True,
     clip_rewards=True,
     frame_stack=4,
@@ -293,7 +299,7 @@ def wrap_deepmind(
 
     The observation is channel-first: (c, h, w) instead of (h, w, c).
 
-    :param str env_id: the atari environment id.
+    :param env: the Atari environment to wrap.
     :param bool episode_life: wrap the episode life wrapper.
     :param bool clip_rewards: wrap the reward clipping wrapper.
     :param int frame_stack: wrap the frame stacking wrapper.
@@ -301,8 +307,6 @@ def wrap_deepmind(
     :param bool warp_frame: wrap the grayscale + resize observation wrapper.
     :return: the wrapped atari environment.
     """
-    assert "NoFrameskip" in env_id
-    env = gym.make(env_id)
     env = NoopResetEnv(env, noop_max=30)
     env = MaxAndSkipEnv(env, skip=4)
     if episode_life:
@@ -351,19 +355,30 @@ def make_atari_env(task, seed, training_num, test_num, **kwargs):
             stack_num=kwargs.get("frame_stack", 4),
         )
     else:
+        assert "NoFrameskip" in task
        warnings.warn(
             "Recommend using envpool (pip install envpool) to run Atari games more efficiently.",
         )
         env = wrap_deepmind(task, **kwargs)
         train_envs = ShmemVectorEnv(
             [
-                lambda: wrap_deepmind(task, episode_life=True, clip_rewards=True, **kwargs)
+                lambda: wrap_deepmind(
+                    gym.make(task),
+                    episode_life=True,
+                    clip_rewards=True,
+                    **kwargs,
+                )
                 for _ in range(training_num)
             ],
         )
         test_envs = ShmemVectorEnv(
             [
-                lambda: wrap_deepmind(task, episode_life=False, clip_rewards=False, **kwargs)
+                lambda: wrap_deepmind(
+                    gym.make(task),
+                    episode_life=False,
+                    clip_rewards=False,
+                    **kwargs,
+                )
                 for _ in range(test_num)
             ],
         )
@@ -373,23 +388,49 @@ def make_atari_env(task, seed, training_num, test_num, **kwargs):
     return env, train_envs, test_envs
 
 
-class AtariEnvFactory(EnvFactory):
-    def __init__(self, task: str, seed: int, frame_stack: int, scale: int = 0):
-        self.task = task
-        self.seed = seed
+class AtariEnvFactory(EnvFactoryGymnasium):
+    def __init__(self, task: str, seed: int, frame_stack: int, scale: bool = False):
+        assert "NoFrameskip" in task
         self.frame_stack = frame_stack
         self.scale = scale
-
-    def create_envs(self, num_training_envs: int, num_test_envs: int) -> DiscreteEnvironments:
-        env, train_envs, test_envs = make_atari_env(
-            task=self.task,
-            seed=self.seed,
-            training_num=num_training_envs,
-            test_num=num_test_envs,
-            scale=self.scale,
-            frame_stack=self.frame_stack,
+        super().__init__(
+            task=task,
+            seed=seed,
+            venv_type=VectorEnvType.SUBPROC_SHARED_MEM,
+            envpool_factory=self.EnvPoolFactory(self),
         )
-        return DiscreteEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)
+
+    def create_env(self, mode: EnvMode) -> Env:
+        env = super().create_env(mode)
+        is_train = mode == EnvMode.TRAIN
+        return wrap_deepmind(
+            env,
+            episode_life=is_train,
+            clip_rewards=is_train,
+            frame_stack=self.frame_stack,
+            scale=self.scale,
+        )
+
+    class EnvPoolFactory(EnvPoolFactory):
+        def __init__(self, parent: "AtariEnvFactory"):
+            self.parent = parent
+            if self.parent.scale:
+                warnings.warn(
+                    "EnvPool does not include ScaledFloatFrame wrapper, "
+                    "please compensate by scaling inside your network's forward function (e.g. `x = x / 255.0` for Atari)",
+                )
+
+        def _transform_task(self, task: str) -> str:
+            task = super()._transform_task(task)
+            return task.replace("NoFrameskip-v4", "-v5")
+
+        def _transform_kwargs(self, kwargs: dict, mode: EnvMode) -> dict:
+            kwargs = super()._transform_kwargs(kwargs, mode)
+            is_train = mode == EnvMode.TRAIN
+            kwargs["reward_clip"] = is_train
+            kwargs["episodic_life"] = is_train
+            kwargs["stack_num"] = self.parent.frame_stack
+            return kwargs
 
 
 class AtariStopCallback(TrainerStopCallback):

@@ -1,11 +1,13 @@
 import logging
 import pickle
-import warnings
 
-import gymnasium as gym
-
-from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
-from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
+from tianshou.env import VectorEnvNormObs
+from tianshou.highlevel.env import (
+    ContinuousEnvironments,
+    EnvFactoryGymnasium,
+    EnvPoolFactory,
+    VectorEnvType,
+)
 from tianshou.highlevel.persistence import Persistence, PersistEvent, RestoreEvent
 from tianshou.highlevel.world import World
 
@@ -24,25 +26,11 @@ def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: in
 
     :return: a tuple of (single env, training envs, test envs).
     """
-    if envpool is not None:
-        train_envs = env = envpool.make_gymnasium(task, num_envs=num_train_envs, seed=seed)
-        test_envs = envpool.make_gymnasium(task, num_envs=num_test_envs, seed=seed)
-    else:
-        warnings.warn(
-            "Recommend using envpool (pip install envpool) "
-            "to run Mujoco environments more efficiently.",
+    envs = MujocoEnvFactory(task, seed, obs_norm=obs_norm).create_envs(
+        num_train_envs,
+        num_test_envs,
     )
-        env = gym.make(task)
-        train_envs = ShmemVectorEnv([lambda: gym.make(task) for _ in range(num_train_envs)])
-        test_envs = ShmemVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)])
-        train_envs.seed(seed)
-        test_envs.seed(seed)
-    if obs_norm:
-        # obs norm wrapper
-        train_envs = VectorEnvNormObs(train_envs)
-        test_envs = VectorEnvNormObs(test_envs, update_obs_rms=False)
-        test_envs.set_obs_rms(train_envs.get_obs_rms())
-    return env, train_envs, test_envs
+    return envs.env, envs.train_envs, envs.test_envs
 
 
 class MujocoEnvObsRmsPersistence(Persistence):
@@ -68,21 +56,25 @@ class MujocoEnvObsRmsPersistence(Persistence):
         world.envs.test_envs.set_obs_rms(obs_rms)
 
 
-class MujocoEnvFactory(EnvFactory):
+class MujocoEnvFactory(EnvFactoryGymnasium):
     def __init__(self, task: str, seed: int, obs_norm=True):
-        self.task = task
-        self.seed = seed
+        super().__init__(
+            task=task,
+            seed=seed,
+            venv_type=VectorEnvType.SUBPROC_SHARED_MEM,
+            envpool_factory=EnvPoolFactory(),
+        )
         self.obs_norm = obs_norm
 
     def create_envs(self, num_training_envs: int, num_test_envs: int) -> ContinuousEnvironments:
-        env, train_envs, test_envs = make_mujoco_env(
-            task=self.task,
-            seed=self.seed,
-            num_train_envs=num_training_envs,
-            num_test_envs=num_test_envs,
-            obs_norm=self.obs_norm,
-        )
-        envs = ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)
+        envs = super().create_envs(num_training_envs, num_test_envs)
+        assert isinstance(envs, ContinuousEnvironments)
+
+        # obs norm wrapper
         if self.obs_norm:
+            envs.train_envs = VectorEnvNormObs(envs.train_envs)
+            envs.test_envs = VectorEnvNormObs(envs.test_envs, update_obs_rms=False)
+            envs.test_envs.set_obs_rms(envs.train_envs.get_obs_rms())
             envs.set_persistence(MujocoEnvObsRmsPersistence())
+
         return envs

@@ -1,27 +1,14 @@
-import gymnasium as gym
-
-from tianshou.env import DummyVectorEnv
 from tianshou.highlevel.env import (
-    ContinuousEnvironments,
-    DiscreteEnvironments,
-    EnvFactory,
-    Environments,
+    EnvFactoryGymnasium,
+    VectorEnvType,
 )
 
 
-class DiscreteTestEnvFactory(EnvFactory):
-    def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
-        task = "CartPole-v0"
-        env = gym.make(task)
-        train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)])
-        test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)])
-        return DiscreteEnvironments(env, train_envs, test_envs)
+class DiscreteTestEnvFactory(EnvFactoryGymnasium):
+    def __init__(self):
+        super().__init__(task="CartPole-v0", seed=42, venv_type=VectorEnvType.DUMMY)
 
 
-class ContinuousTestEnvFactory(EnvFactory):
-    def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
-        task = "Pendulum-v1"
-        env = gym.make(task)
-        train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)])
-        test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)])
-        return ContinuousEnvironments(env, train_envs, test_envs)
+class ContinuousTestEnvFactory(EnvFactoryGymnasium):
+    def __init__(self):
+        super().__init__(task="Pendulum-v1", seed=42, venv_type=VectorEnvType.DUMMY)

@@ -1,9 +1,12 @@
+import logging
 from abc import ABC, abstractmethod
 from collections.abc import Callable, Sequence
 from enum import Enum
 from typing import Any, TypeAlias, cast
 
 import gymnasium as gym
+import gymnasium.spaces
+from gymnasium import Env
 
 from tianshou.env import (
     BaseVectorEnv,
@@ -18,6 +21,8 @@ from tianshou.utils.string import ToStringMixin
 
 TObservationShape: TypeAlias = int | Sequence[int]
 
+log = logging.getLogger(__name__)
+
 
 class EnvType(Enum):
     """Enumeration of environment types."""
@@ -39,6 +44,23 @@ class EnvType(Enum):
         if not self.is_discrete():
             raise AssertionError(f"{requiring_entity} requires discrete environments")
 
+    @staticmethod
+    def from_env(env: Env) -> "EnvType":
+        if isinstance(env.action_space, gymnasium.spaces.Discrete):
+            return EnvType.DISCRETE
+        elif isinstance(env.action_space, gymnasium.spaces.Box):
+            return EnvType.CONTINUOUS
+        else:
+            raise Exception(f"Unsupported environment type with action space {env.action_space}")
+
+
+class EnvMode(Enum):
+    """Indicates the purpose for which an environment is created."""
+
+    TRAIN = "train"
+    TEST = "test"
+    WATCH = "watch"
+
 
 class VectorEnvType(Enum):
     DUMMY = "dummy"
@@ -65,7 +87,7 @@ class VectorEnvType(Enum):
 
 
 class Environments(ToStringMixin, ABC):
-    """Represents (vectorized) environments."""
+    """Represents (vectorized) environments for a learning process."""
 
     def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
         self.env = env
@@ -75,12 +97,11 @@ class Environments(ToStringMixin, ABC):
 
     @staticmethod
     def from_factory_and_type(
-        factory_fn: Callable[[], gym.Env],
+        factory_fn: Callable[[EnvMode], gym.Env],
         env_type: EnvType,
         venv_type: VectorEnvType,
         num_training_envs: int,
         num_test_envs: int,
-        test_factory_fn: Callable[[], gym.Env] | None = None,
     ) -> "Environments":
         """Creates a suitable subtype instance from a factory function that creates a single instance and the type of environment (continuous/discrete).
 
@@ -89,15 +110,11 @@ class Environments(ToStringMixin, ABC):
         :param venv_type: the vector environment type to use for parallelization
         :param num_training_envs: the number of training environments to create
         :param num_test_envs: the number of test environments to create
-        :param test_factory_fn: the factory to use for the creation of test environment instances;
-            if None, use `factory_fn` for all environments (train and test)
         :return: the instance
         """
-        if test_factory_fn is None:
-            test_factory_fn = factory_fn
-        train_envs = venv_type.create_venv([factory_fn] * num_training_envs)
-        test_envs = venv_type.create_venv([test_factory_fn] * num_test_envs)
-        env = factory_fn()
+        train_envs = venv_type.create_venv([lambda: factory_fn(EnvMode.TRAIN)] * num_training_envs)
+        test_envs = venv_type.create_venv([lambda: factory_fn(EnvMode.TEST)] * num_test_envs)
+        env = factory_fn(EnvMode.TRAIN)
         match env_type:
             case EnvType.CONTINUOUS:
                 return ContinuousEnvironments(env, train_envs, test_envs)
@@ -153,11 +170,10 @@ class ContinuousEnvironments(Environments):
 
     @staticmethod
     def from_factory(
-        factory_fn: Callable[[], gym.Env],
+        factory_fn: Callable[[EnvMode], gym.Env],
         venv_type: VectorEnvType,
         num_training_envs: int,
         num_test_envs: int,
-        test_factory_fn: Callable[[], gym.Env] | None = None,
     ) -> "ContinuousEnvironments":
         """Creates an instance from a factory function that creates a single instance.
 
@@ -165,8 +181,6 @@ class ContinuousEnvironments(Environments):
         :param venv_type: the vector environment type to use for parallelization
         :param num_training_envs: the number of training environments to create
         :param num_test_envs: the number of test environments to create
-        :param test_factory_fn: the factory to use for the creation of test environment instances;
-            if None, use `factory_fn` for all environments (train and test)
         :return: the instance
         """
         return cast(
@@ -177,7 +191,6 @@ class ContinuousEnvironments(Environments):
                 venv_type,
                 num_training_envs,
                 num_test_envs,
-                test_factory_fn=test_factory_fn,
             ),
         )
 
@@ -222,11 +235,10 @@ class DiscreteEnvironments(Environments):
 
     @staticmethod
     def from_factory(
-        factory_fn: Callable[[], gym.Env],
+        factory_fn: Callable[[EnvMode], gym.Env],
         venv_type: VectorEnvType,
         num_training_envs: int,
         num_test_envs: int,
-        test_factory_fn: Callable[[], gym.Env] | None = None,
     ) -> "DiscreteEnvironments":
         """Creates an instance from a factory function that creates a single instance.
 
@@ -234,8 +246,6 @@ class DiscreteEnvironments(Environments):
         :param venv_type: the vector environment type to use for parallelization
         :param num_training_envs: the number of training environments to create
         :param num_test_envs: the number of test environments to create
-        :param test_factory_fn: the factory to use for the creation of test environment instances;
-            if None, use `factory_fn` for all environments (train and test)
         :return: the instance
         """
         return cast(
@@ -246,7 +256,6 @@ class DiscreteEnvironments(Environments):
                 venv_type,
                 num_training_envs,
                 num_test_envs,
-                test_factory_fn=test_factory_fn,
             ),
         )
 
@@ -260,7 +269,156 @@ class DiscreteEnvironments(Environments):
         return EnvType.DISCRETE
 
 
+class EnvPoolFactory:
+    """A factory for the creation of envpool-based vectorized environments."""
+
+    def _transform_task(self, task: str) -> str:
+        return task
+
+    def _transform_kwargs(self, kwargs: dict, mode: EnvMode) -> dict:
+        """Transforms gymnasium keyword arguments to be envpool-compatible.
+
+        :param kwargs: keyword arguments that would normally be passed to `gymnasium.make`.
+        :param mode: the environment mode
+        :return: the transformed keyword arguments
+        """
+        kwargs = dict(kwargs)
+        if "render_mode" in kwargs:
+            del kwargs["render_mode"]
+        return kwargs
+
+    def create_venv(
+        self,
+        task: str,
+        num_envs: int,
+        mode: EnvMode,
+        seed: int,
+        kwargs: dict,
+    ) -> BaseVectorEnv | None:
+        try:
+            import envpool
+
+            envpool_task = self._transform_task(task)
+            envpool_kwargs = self._transform_kwargs(kwargs, mode)
+            return envpool.make_gymnasium(
+                envpool_task,
+                num_envs=num_envs,
+                seed=seed,
+                **envpool_kwargs,
+            )
+        except ImportError:
+            return None
+
+
 class EnvFactory(ToStringMixin, ABC):
+    def __init__(self, venv_type: VectorEnvType):
+        """:param venv_type: the type of vectorized environment to use"""
+        self.venv_type = venv_type
+
     @abstractmethod
-    def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
+    def create_env(self, mode: EnvMode) -> Env:
         pass
+
+    def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv:
+        """Create vectorized environments.
+
+        :param num_envs: the number of environments
+        :param mode: the mode for which to create
+        :return: the vectorized environments
+        """
+        return self.venv_type.create_venv([lambda: self.create_env(mode)] * num_envs)
+
+    def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
+        """Create environments for learning.
+
+        :param num_training_envs: the number of training environments
+        :param num_test_envs: the number of test environments
+        :return: the environments
+        """
+        env = self.create_env(EnvMode.TRAIN)
+        train_envs = self.create_venv(num_training_envs, EnvMode.TRAIN)
+        test_envs = self.create_venv(num_test_envs, EnvMode.TEST)
+        match EnvType.from_env(env):
+            case EnvType.DISCRETE:
+                return DiscreteEnvironments(env, train_envs, test_envs)
+            case EnvType.CONTINUOUS:
+                return ContinuousEnvironments(env, train_envs, test_envs)
+            case _:
+                raise ValueError
+
+
+class EnvFactoryGymnasium(EnvFactory):
+    """Factory for environments that can be created via `gymnasium.make` (or via `envpool.make_gymnasium`)."""
+
+    def __init__(
+        self,
+        *,
+        task: str,
+        seed: int,
+        venv_type: VectorEnvType,
+        envpool_factory: EnvPoolFactory | None = None,
+        render_mode_train: str | None = None,
+        render_mode_test: str | None = None,
+        render_mode_watch: str = "human",
+        **kwargs: Any,
+    ):
+        """:param task: the gymnasium task/environment identifier
+        :param seed: the random seed
+        :param venv_type: the type of vectorized environment to use. If `envpool_factory` is specified, this is but a fallback.
+        :param envpool_factory: the factory to use for envpool-based vectorized environment creation if `envpool` is installed.
+            If it is not installed, `venv_type` applies as a fallback.
+        :param render_mode_train: the render mode to use for training environments
+        :param render_mode_test: the render mode to use for test environments
+        :param render_mode_watch: the render mode to use for environments that are used to watch agent performance
+        :param kwargs: additional keyword arguments to pass on to `gymnasium.make`.
+            If envpool is used, the gymnasium parameters will be appropriately translated for use with
+            `envpool.make_gymnasium`.
+        """
+        super().__init__(venv_type)
+        self.task = task
+        self.envpool_factory = envpool_factory
+        self.seed = seed
+        self.render_modes = {
+            EnvMode.TRAIN: render_mode_train,
+            EnvMode.TEST: render_mode_test,
+            EnvMode.WATCH: render_mode_watch,
+        }
+        self.kwargs = kwargs
+
+    def _create_kwargs(self, mode: EnvMode) -> dict:
+        """Adapts the keyword arguments for the given mode.
+
+        :param mode: the mode
+        :return: adapted keyword arguments
+        """
+        kwargs = dict(self.kwargs)
+        kwargs["render_mode"] = self.render_modes.get(mode)
+        return kwargs
+
+    def create_env(self, mode: EnvMode) -> Env:
+        """Creates a single environment for the given mode.
+
+        :param mode: the mode
+        :return: an environment
+        """
+        kwargs = self._create_kwargs(mode)
+        return gymnasium.make(self.task, **kwargs)
+
+    def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv:
+        if self.envpool_factory is not None:
+            venv = self.envpool_factory.create_venv(
+                self.task,
+                num_envs,
+                mode,
+                self.seed,
+                self._create_kwargs(mode),
+            )
+            if venv is not None:
+                return venv
+            log.debug(
+                f"EnvPool-based creation could not be applied, falling back to default based on {self.venv_type}",
+            )
+
+        venv = super().create_venv(num_envs, mode)
+        venv.seed(self.seed)
+        return venv

@@ -26,7 +26,7 @@ from tianshou.highlevel.agent import (
     TRPOAgentFactory,
 )
 from tianshou.highlevel.config import SamplingConfig
-from tianshou.highlevel.env import EnvFactory
+from tianshou.highlevel.env import EnvFactory, EnvMode
 from tianshou.highlevel.logger import LoggerFactory, LoggerFactoryDefault, TLogger
 from tianshou.highlevel.module.actor import (
     ActorFactory,
@@ -293,7 +293,7 @@ class Experiment(ToStringMixin):
                 self._watch_agent(
                     self.config.watch_num_episodes,
                     policy,
-                    test_collector,
+                    self.env_factory,
                     self.config.watch_render,
                 )
 
@@ -303,15 +303,18 @@ class Experiment(ToStringMixin):
     def _watch_agent(
         num_episodes: int,
         policy: BasePolicy,
-        test_collector: Collector,
+        env_factory: EnvFactory,
         render: float,
     ) -> None:
         policy.eval()
-        test_collector.reset()
-        result = test_collector.collect(n_episode=num_episodes, render=render)
+        env = env_factory.create_env(EnvMode.WATCH)
+        collector = Collector(policy, env)
+        result = collector.collect(n_episode=num_episodes, render=render)
         assert result.returns_stat is not None  # for mypy
         assert result.lens_stat is not None  # for mypy
-        print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}")
+        log.info(
+            f"Watched episodes: mean reward={result.returns_stat.mean}, mean episode length={result.lens_stat.mean}",
+        )
 
 
 class ExperimentBuilder: