diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py
index 2018d1f..a7d7ea5 100644
--- a/examples/atari/atari_wrapper.py
+++ b/examples/atari/atari_wrapper.py
@@ -7,9 +7,15 @@ from collections import deque
 import cv2
 import gymnasium as gym
 import numpy as np
+from gymnasium import Env
 
 from tianshou.env import ShmemVectorEnv
-from tianshou.highlevel.env import DiscreteEnvironments, EnvFactory
+from tianshou.highlevel.env import (
+    EnvFactoryGymnasium,
+    EnvMode,
+    EnvPoolFactory,
+    VectorEnvType,
+)
 from tianshou.highlevel.trainer import TrainerStopCallback, TrainingContext
 
 try:
@@ -282,7 +288,7 @@ class FrameStack(gym.Wrapper):
 
 
 def wrap_deepmind(
-    env_id,
+    env: Env,
     episode_life=True,
     clip_rewards=True,
     frame_stack=4,
@@ -293,7 +299,7 @@
 
     The observation is channel-first: (c, h, w) instead of (h, w, c).
 
-    :param str env_id: the atari environment id.
+    :param env: the Atari environment to wrap.
     :param bool episode_life: wrap the episode life wrapper.
     :param bool clip_rewards: wrap the reward clipping wrapper.
     :param int frame_stack: wrap the frame stacking wrapper.
@@ -301,8 +307,6 @@
     :param bool warp_frame: wrap the grayscale + resize observation wrapper.
     :return: the wrapped atari environment.
     """
-    assert "NoFrameskip" in env_id
-    env = gym.make(env_id)
     env = NoopResetEnv(env, noop_max=30)
     env = MaxAndSkipEnv(env, skip=4)
     if episode_life:
@@ -351,19 +355,30 @@ def make_atari_env(task, seed, training_num, test_num, **kwargs):
             stack_num=kwargs.get("frame_stack", 4),
         )
     else:
+        assert "NoFrameskip" in task
         warnings.warn(
             "Recommend using envpool (pip install envpool) to run Atari games more efficiently.",
         )
-        env = wrap_deepmind(task, **kwargs)
+        env = wrap_deepmind(gym.make(task), **kwargs)
         train_envs = ShmemVectorEnv(
             [
-                lambda: wrap_deepmind(task, episode_life=True, clip_rewards=True, **kwargs)
+                lambda: wrap_deepmind(
+                    gym.make(task),
+                    episode_life=True,
+                    clip_rewards=True,
+                    **kwargs,
+                )
                 for _ in range(training_num)
             ],
         )
         test_envs = ShmemVectorEnv(
             [
-                lambda: wrap_deepmind(task, episode_life=False, clip_rewards=False, **kwargs)
+                lambda: wrap_deepmind(
+                    gym.make(task),
+                    episode_life=False,
+                    clip_rewards=False,
+                    **kwargs,
+                )
                 for _ in range(test_num)
             ],
         )
@@ -373,23 +388,49 @@ def make_atari_env(task, seed, training_num, test_num, **kwargs):
     return env, train_envs, test_envs
 
 
-class AtariEnvFactory(EnvFactory):
-    def __init__(self, task: str, seed: int, frame_stack: int, scale: int = 0):
-        self.task = task
-        self.seed = seed
+class AtariEnvFactory(EnvFactoryGymnasium):
+    def __init__(self, task: str, seed: int, frame_stack: int, scale: bool = False):
+        assert "NoFrameskip" in task
         self.frame_stack = frame_stack
         self.scale = scale
-
-    def create_envs(self, num_training_envs: int, num_test_envs: int) -> DiscreteEnvironments:
-        env, train_envs, test_envs = make_atari_env(
-            task=self.task,
-            seed=self.seed,
-            training_num=num_training_envs,
-            test_num=num_test_envs,
-            scale=self.scale,
-            frame_stack=self.frame_stack,
+        super().__init__(
+            task=task,
+            seed=seed,
+            venv_type=VectorEnvType.SUBPROC_SHARED_MEM,
+            envpool_factory=self.EnvPoolFactoryAtari(self),
         )
-        return DiscreteEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)
+
+    def create_env(self, mode: EnvMode) -> Env:
+        env = super().create_env(mode)
+        is_train = mode == EnvMode.TRAIN
+        return wrap_deepmind(
+            env,
+            episode_life=is_train,
+            clip_rewards=is_train,
+            frame_stack=self.frame_stack,
+            scale=self.scale,
+        )
+
+    class EnvPoolFactoryAtari(EnvPoolFactory):
+        def __init__(self, parent: "AtariEnvFactory"):
+            self.parent = parent
+            if self.parent.scale:
+                warnings.warn(
+                    "EnvPool does not include the ScaledFloatFrame wrapper; please compensate "
+                    "by scaling inside your network's forward function (e.g. `x = x / 255.0` for Atari)",
+                )
+
+        def _transform_task(self, task: str) -> str:
+            task = super()._transform_task(task)
+            return task.replace("NoFrameskip-v4", "-v5")
+
+        def _transform_kwargs(self, kwargs: dict, mode: EnvMode) -> dict:
+            kwargs = super()._transform_kwargs(kwargs, mode)
+            is_train = mode == EnvMode.TRAIN
+            kwargs["reward_clip"] = is_train
+            kwargs["episodic_life"] = is_train
+            kwargs["stack_num"] = self.parent.frame_stack
+            return kwargs
 
 
 class AtariStopCallback(TrainerStopCallback):
diff --git a/examples/mujoco/mujoco_env.py b/examples/mujoco/mujoco_env.py
index 3a48121..68a0a12 100644
--- a/examples/mujoco/mujoco_env.py
+++ b/examples/mujoco/mujoco_env.py
@@ -1,11 +1,13 @@
 import logging
 import pickle
-import warnings
 
-import gymnasium as gym
-
-from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
-from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
+from tianshou.env import VectorEnvNormObs
+from tianshou.highlevel.env import (
+    ContinuousEnvironments,
+    EnvFactoryGymnasium,
+    EnvPoolFactory,
+    VectorEnvType,
+)
 from tianshou.highlevel.persistence import Persistence, PersistEvent, RestoreEvent
 from tianshou.highlevel.world import World
 
@@ -24,25 +26,11 @@ def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: in
 
     :return: a tuple of (single env, training envs, test envs).
     """
-    if envpool is not None:
-        train_envs = env = envpool.make_gymnasium(task, num_envs=num_train_envs, seed=seed)
-        test_envs = envpool.make_gymnasium(task, num_envs=num_test_envs, seed=seed)
-    else:
-        warnings.warn(
-            "Recommend using envpool (pip install envpool) "
-            "to run Mujoco environments more efficiently.",
-        )
-        env = gym.make(task)
-        train_envs = ShmemVectorEnv([lambda: gym.make(task) for _ in range(num_train_envs)])
-        test_envs = ShmemVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)])
-    train_envs.seed(seed)
-    test_envs.seed(seed)
-    if obs_norm:
-        # obs norm wrapper
-        train_envs = VectorEnvNormObs(train_envs)
-        test_envs = VectorEnvNormObs(test_envs, update_obs_rms=False)
-        test_envs.set_obs_rms(train_envs.get_obs_rms())
-    return env, train_envs, test_envs
+    envs = MujocoEnvFactory(task, seed, obs_norm=obs_norm).create_envs(
+        num_train_envs,
+        num_test_envs,
+    )
+    return envs.env, envs.train_envs, envs.test_envs
 
 
 class MujocoEnvObsRmsPersistence(Persistence):
@@ -68,21 +56,25 @@
         world.envs.test_envs.set_obs_rms(obs_rms)
 
 
-class MujocoEnvFactory(EnvFactory):
+class MujocoEnvFactory(EnvFactoryGymnasium):
     def __init__(self, task: str, seed: int, obs_norm=True):
-        self.task = task
-        self.seed = seed
+        super().__init__(
+            task=task,
+            seed=seed,
+            venv_type=VectorEnvType.SUBPROC_SHARED_MEM,
+            envpool_factory=EnvPoolFactory(),
+        )
         self.obs_norm = obs_norm
 
     def create_envs(self, num_training_envs: int, num_test_envs: int) -> ContinuousEnvironments:
-        env, train_envs, test_envs = make_mujoco_env(
-            task=self.task,
-            seed=self.seed,
-            num_train_envs=num_training_envs,
-            num_test_envs=num_test_envs,
-            obs_norm=self.obs_norm,
-        )
-        envs = ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)
+        envs = super().create_envs(num_training_envs, num_test_envs)
+        assert isinstance(envs, ContinuousEnvironments)
+
+        # obs norm wrapper
         if self.obs_norm:
+            envs.train_envs = VectorEnvNormObs(envs.train_envs)
+            envs.test_envs = VectorEnvNormObs(envs.test_envs, update_obs_rms=False)
+            envs.test_envs.set_obs_rms(envs.train_envs.get_obs_rms())
             envs.set_persistence(MujocoEnvObsRmsPersistence())
+
         return envs
diff --git a/test/highlevel/env_factory.py b/test/highlevel/env_factory.py
index 8ed89b1..760bb5b 100644
--- a/test/highlevel/env_factory.py
+++ b/test/highlevel/env_factory.py
@@ -1,27 +1,14 @@
-import gymnasium as gym
-
-from tianshou.env import DummyVectorEnv
 from tianshou.highlevel.env import (
-    ContinuousEnvironments,
-    DiscreteEnvironments,
-    EnvFactory,
-    Environments,
+    EnvFactoryGymnasium,
+    VectorEnvType,
 )
 
 
-class DiscreteTestEnvFactory(EnvFactory):
-    def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
-        task = "CartPole-v0"
-        env = gym.make(task)
-        train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)])
-        test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)])
-        return DiscreteEnvironments(env, train_envs, test_envs)
+class DiscreteTestEnvFactory(EnvFactoryGymnasium):
+    def __init__(self):
+        super().__init__(task="CartPole-v0", seed=42, venv_type=VectorEnvType.DUMMY)
 
 
-class ContinuousTestEnvFactory(EnvFactory):
-    def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
-        task = "Pendulum-v1"
-        env = gym.make(task)
-        train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)])
-        test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)])
-        return ContinuousEnvironments(env, train_envs, test_envs)
+class ContinuousTestEnvFactory(EnvFactoryGymnasium):
+    def __init__(self):
+        super().__init__(task="Pendulum-v1", seed=42, venv_type=VectorEnvType.DUMMY)
diff --git a/tianshou/highlevel/env.py b/tianshou/highlevel/env.py
index 9284d55..915f9ad 100644
--- a/tianshou/highlevel/env.py
+++ b/tianshou/highlevel/env.py
@@ -1,9 +1,12 @@
+import logging
 from abc import ABC, abstractmethod
 from collections.abc import Callable, Sequence
 from enum import Enum
 from typing import Any, TypeAlias, cast
 
 import gymnasium as gym
+import gymnasium.spaces
+from gymnasium import Env
 
 from tianshou.env import (
     BaseVectorEnv,
@@ -18,6 +21,8 @@ from tianshou.utils.string import ToStringMixin
 
 TObservationShape: TypeAlias = int | Sequence[int]
 
+log = logging.getLogger(__name__)
+
 
 class EnvType(Enum):
     """Enumeration of environment types."""
@@ -39,6 +44,23 @@
         if not self.is_discrete():
             raise AssertionError(f"{requiring_entity} requires discrete environments")
 
+    @staticmethod
+    def from_env(env: Env) -> "EnvType":
+        if isinstance(env.action_space, gymnasium.spaces.Discrete):
+            return EnvType.DISCRETE
+        elif isinstance(env.action_space, gymnasium.spaces.Box):
+            return EnvType.CONTINUOUS
+        else:
+            raise ValueError(f"Unsupported environment type with action space {env.action_space}")
+
+
+class EnvMode(Enum):
+    """Indicates the purpose for which an environment is created."""
+
+    TRAIN = "train"
+    TEST = "test"
+    WATCH = "watch"
+
 
 class VectorEnvType(Enum):
     DUMMY = "dummy"
@@ -65,7 +87,7 @@
 
 class Environments(ToStringMixin, ABC):
-    """Represents (vectorized) environments."""
+    """Represents (vectorized) environments for a learning process."""
 
     def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
         self.env = env
         self.train_envs = train_envs
@@ -75,12 +97,11 @@
 
     @staticmethod
     def from_factory_and_type(
-        factory_fn: Callable[[], gym.Env],
+        factory_fn: Callable[[EnvMode], gym.Env],
         env_type: EnvType,
         venv_type: VectorEnvType,
         num_training_envs: int,
         num_test_envs: int,
-        test_factory_fn: Callable[[], gym.Env] | None = None,
     ) -> "Environments":
         """Creates a suitable subtype instance from a factory function that creates a single instance
         and the type of environment (continuous/discrete).
@@ -89,15 +110,11 @@
         :param venv_type: the vector environment type to use for parallelization
         :param num_training_envs: the number of training environments to create
         :param num_test_envs: the number of test environments to create
-        :param test_factory_fn: the factory to use for the creation of test environment instances;
-            if None, use `factory_fn` for all environments (train and test)
         :return: the instance
         """
-        if test_factory_fn is None:
-            test_factory_fn = factory_fn
-        train_envs = venv_type.create_venv([factory_fn] * num_training_envs)
-        test_envs = venv_type.create_venv([test_factory_fn] * num_test_envs)
-        env = factory_fn()
+        train_envs = venv_type.create_venv([lambda: factory_fn(EnvMode.TRAIN)] * num_training_envs)
+        test_envs = venv_type.create_venv([lambda: factory_fn(EnvMode.TEST)] * num_test_envs)
+        env = factory_fn(EnvMode.TRAIN)
         match env_type:
             case EnvType.CONTINUOUS:
                 return ContinuousEnvironments(env, train_envs, test_envs)
@@ -153,11 +170,10 @@
 
     @staticmethod
     def from_factory(
-        factory_fn: Callable[[], gym.Env],
+        factory_fn: Callable[[EnvMode], gym.Env],
         venv_type: VectorEnvType,
         num_training_envs: int,
         num_test_envs: int,
-        test_factory_fn: Callable[[], gym.Env] | None = None,
     ) -> "ContinuousEnvironments":
         """Creates an instance from a factory function that creates a single instance.
 
@@ -165,8 +181,6 @@
         :param venv_type: the vector environment type to use for parallelization
         :param num_training_envs: the number of training environments to create
         :param num_test_envs: the number of test environments to create
-        :param test_factory_fn: the factory to use for the creation of test environment instances;
-            if None, use `factory_fn` for all environments (train and test)
         :return: the instance
         """
         return cast(
@@ -177,7 +191,6 @@
                 venv_type,
                 num_training_envs,
                 num_test_envs,
-                test_factory_fn=test_factory_fn,
             ),
         )
 
@@ -222,11 +235,10 @@
 
     @staticmethod
     def from_factory(
-        factory_fn: Callable[[], gym.Env],
+        factory_fn: Callable[[EnvMode], gym.Env],
         venv_type: VectorEnvType,
        num_training_envs: int,
        num_test_envs: int,
-        test_factory_fn: Callable[[], gym.Env] | None = None,
     ) -> "DiscreteEnvironments":
         """Creates an instance from a factory function that creates a single instance.
 
@@ -234,8 +246,6 @@
         :param venv_type: the vector environment type to use for parallelization
         :param num_training_envs: the number of training environments to create
         :param num_test_envs: the number of test environments to create
-        :param test_factory_fn: the factory to use for the creation of test environment instances;
-            if None, use `factory_fn` for all environments (train and test)
         :return: the instance
         """
         return cast(
@@ -246,7 +256,6 @@
                 venv_type,
                 num_training_envs,
                 num_test_envs,
-                test_factory_fn=test_factory_fn,
             ),
         )
 
@@ -260,7 +269,156 @@
         return EnvType.DISCRETE
 
 
+class EnvPoolFactory:
+    """A factory for the creation of envpool-based vectorized environments."""
+
+    def _transform_task(self, task: str) -> str:
+        return task
+
+    def _transform_kwargs(self, kwargs: dict, mode: EnvMode) -> dict:
+        """Transforms gymnasium keyword arguments to be envpool-compatible.
+
+        :param kwargs: keyword arguments that would normally be passed to `gymnasium.make`.
+        :param mode: the environment mode
+        :return: the transformed keyword arguments
+        """
+        kwargs = dict(kwargs)
+        if "render_mode" in kwargs:
+            del kwargs["render_mode"]
+        return kwargs
+
+    def create_venv(
+        self,
+        task: str,
+        num_envs: int,
+        mode: EnvMode,
+        seed: int,
+        kwargs: dict,
+    ) -> BaseVectorEnv | None:
+        try:
+            import envpool
+
+            envpool_task = self._transform_task(task)
+            envpool_kwargs = self._transform_kwargs(kwargs, mode)
+            return envpool.make_gymnasium(
+                envpool_task,
+                num_envs=num_envs,
+                seed=seed,
+                **envpool_kwargs,
+            )
+        except ImportError:
+            return None
+
+
 class EnvFactory(ToStringMixin, ABC):
+    def __init__(self, venv_type: VectorEnvType):
+        """:param venv_type: the type of vectorized environment to use"""
+        self.venv_type = venv_type
+
     @abstractmethod
-    def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
+    def create_env(self, mode: EnvMode) -> Env:
         pass
+
+    def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv:
+        """Create vectorized environments.
+
+        :param num_envs: the number of environments
+        :param mode: the mode for which to create the environments
+        :return: the vectorized environments
+        """
+        return self.venv_type.create_venv([lambda: self.create_env(mode)] * num_envs)
+
+    def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
+        """Create environments for learning.
+
+        :param num_training_envs: the number of training environments
+        :param num_test_envs: the number of test environments
+        :return: the environments
+        """
+        env = self.create_env(EnvMode.TRAIN)
+        train_envs = self.create_venv(num_training_envs, EnvMode.TRAIN)
+        test_envs = self.create_venv(num_test_envs, EnvMode.TEST)
+        match EnvType.from_env(env):
+            case EnvType.DISCRETE:
+                return DiscreteEnvironments(env, train_envs, test_envs)
+            case EnvType.CONTINUOUS:
+                return ContinuousEnvironments(env, train_envs, test_envs)
+            case _:
+                raise ValueError(f"Unsupported environment type for env {env}")
+
+
+class EnvFactoryGymnasium(EnvFactory):
+    """Factory for environments that can be created via `gymnasium.make` (or via `envpool.make_gymnasium`)."""
+
+    def __init__(
+        self,
+        *,
+        task: str,
+        seed: int,
+        venv_type: VectorEnvType,
+        envpool_factory: EnvPoolFactory | None = None,
+        render_mode_train: str | None = None,
+        render_mode_test: str | None = None,
+        render_mode_watch: str = "human",
+        **kwargs: Any,
+    ):
+        """:param task: the gymnasium task/environment identifier
+        :param seed: the random seed
+        :param venv_type: the type of vectorized environment to use; if `envpool_factory` is specified, this is only used as a fallback.
+        :param envpool_factory: the factory to use for envpool-based vectorized environment creation if `envpool` is installed.
+            If it is not installed, `venv_type` applies as a fallback.
+        :param render_mode_train: the render mode to use for training environments
+        :param render_mode_test: the render mode to use for test environments
+        :param render_mode_watch: the render mode to use for environments that are used to watch agent performance
+        :param kwargs: additional keyword arguments to pass on to `gymnasium.make`.
+            If envpool is used, the gymnasium parameters will be appropriately translated for use with
+            `envpool.make_gymnasium`.
+        """
+        super().__init__(venv_type)
+        self.task = task
+        self.envpool_factory = envpool_factory
+        self.seed = seed
+        self.render_modes = {
+            EnvMode.TRAIN: render_mode_train,
+            EnvMode.TEST: render_mode_test,
+            EnvMode.WATCH: render_mode_watch,
+        }
+        self.kwargs = kwargs
+
+    def _create_kwargs(self, mode: EnvMode) -> dict:
+        """Adapts the keyword arguments for the given mode.
+
+        :param mode: the mode
+        :return: adapted keyword arguments
+        """
+        kwargs = dict(self.kwargs)
+        kwargs["render_mode"] = self.render_modes.get(mode)
+        return kwargs
+
+    def create_env(self, mode: EnvMode) -> Env:
+        """Creates a single environment for the given mode.
+
+        :param mode: the mode
+        :return: an environment
+        """
+        kwargs = self._create_kwargs(mode)
+        return gymnasium.make(self.task, **kwargs)
+
+    def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv:
+        if self.envpool_factory is not None:
+            venv = self.envpool_factory.create_venv(
+                self.task,
+                num_envs,
+                mode,
+                self.seed,
+                self._create_kwargs(mode),
+            )
+            if venv is not None:
+                return venv
+            log.debug(
+                f"EnvPool-based creation could not be applied; falling back to default creation based on {self.venv_type}",
+            )
+
+        venv = super().create_venv(num_envs, mode)
+        venv.seed(self.seed)
+        return venv
diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py
index 2d5a880..bac469f 100644
--- a/tianshou/highlevel/experiment.py
+++ b/tianshou/highlevel/experiment.py
@@ -26,7 +26,7 @@ from tianshou.highlevel.agent import (
     TRPOAgentFactory,
 )
 from tianshou.highlevel.config import SamplingConfig
-from tianshou.highlevel.env import EnvFactory
+from tianshou.highlevel.env import EnvFactory, EnvMode
 from tianshou.highlevel.logger import LoggerFactory, LoggerFactoryDefault, TLogger
 from tianshou.highlevel.module.actor import (
     ActorFactory,
@@ -293,7 +293,7 @@
             self._watch_agent(
                 self.config.watch_num_episodes,
                 policy,
-                test_collector,
+                self.env_factory,
                 self.config.watch_render,
             )
 
@@ -303,15 +303,18 @@
     def _watch_agent(
         num_episodes: int,
         policy: BasePolicy,
-        test_collector: Collector,
+        env_factory: EnvFactory,
         render: float,
     ) -> None:
         policy.eval()
-        test_collector.reset()
-        result = test_collector.collect(n_episode=num_episodes, render=render)
+        env = env_factory.create_env(EnvMode.WATCH)
+        collector = Collector(policy, env)
+        result = collector.collect(n_episode=num_episodes, render=render)
         assert result.returns_stat is not None  # for mypy
         assert result.lens_stat is not None  # for mypy
-        print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}")
+        log.info(
+            f"Watched episodes: mean reward={result.returns_stat.mean}, mean episode length={result.lens_stat.mean}",
+        )
 
 
 class ExperimentBuilder:
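To illustrate how the refactored pieces fit together: `EnvFactory.create_env(mode)` builds a single environment for a given `EnvMode`, `create_venv` vectorizes it, and `create_envs` bundles a representative environment plus train/test vector environments into an `Environments` instance, using envpool transparently when an `EnvPoolFactory` is configured and the package is installed. Below is a minimal usage sketch of the resulting API; the task id and the import path of the example factory are illustrative assumptions, not something this diff prescribes.

```python
from examples.atari.atari_wrapper import AtariEnvFactory  # assumed import path
from tianshou.highlevel.env import EnvMode

# Training envs receive the episodic-life and reward-clipping wrappers
# (is_train=True); test envs do not; WATCH envs render with "human" by default.
factory = AtariEnvFactory(task="PongNoFrameskip-v4", seed=42, frame_stack=4)

# One representative env plus vectorized train/test envs. If envpool is
# installed, the task id is translated to "Pong-v5" and envpool is used;
# otherwise SUBPROC_SHARED_MEM vectorization is the fallback.
envs = factory.create_envs(num_training_envs=10, num_test_envs=4)

# A separate, freshly created env for watching the trained agent.
watch_env = factory.create_env(EnvMode.WATCH)
```

Custom setups follow the same pattern: subclassing `EnvFactoryGymnasium` suffices for anything `gymnasium.make` can construct. A hypothetical factory with a non-default test render mode, for example:

```python
from tianshou.highlevel.env import EnvFactoryGymnasium, VectorEnvType


class CartPoleEnvFactory(EnvFactoryGymnasium):  # hypothetical example, not part of this diff
    def __init__(self) -> None:
        super().__init__(
            task="CartPole-v1",
            seed=0,
            venv_type=VectorEnvType.DUMMY,
            render_mode_test="rgb_array",  # forwarded to gymnasium.make via _create_kwargs
        )
```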