Make envpool an explicit, optional choice in the environment factories.

Leading text before the first "diff --git" header is commentary, not part of any hunk.

* examples/atari/atari_wrapper.py, examples/mujoco/mujoco_env.py:
  record at import time whether `import envpool` succeeded (`envpool_is_available`)
  and pass an envpool factory to the base class only when it did; otherwise pass None.
  AtariEnvFactory gains `use_envpool_if_available` (default True) and logs which path is taken.
* tianshou/highlevel/env.py:
  the EnvPoolFactory hunk no longer catches ImportError and no longer returns None;
  EnvFactoryGymnasium.create_venv returns the envpool-based venv directly when
  `envpool_factory` is set, and otherwise creates the venv via the superclass and seeds it.

diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py
index f1b3120..6e6c4e5 100644
--- a/examples/atari/atari_wrapper.py
+++ b/examples/atari/atari_wrapper.py
@@ -1,6 +1,6 @@
 # Borrow a lot from openai baselines:
 # https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
-
+import logging
 import warnings
 from collections import deque
 
@@ -17,10 +17,13 @@ from tianshou.highlevel.env import (
 )
 from tianshou.highlevel.trainer import EpochStopCallback, TrainingContext
 
+envpool_is_available = True
 try:
     import envpool
 except ImportError:
+    envpool_is_available = False
     envpool = None
+log = logging.getLogger(__name__)
 
 
 def _parse_reset_result(reset_result):
@@ -343,15 +346,29 @@ def make_atari_env(
 
 
 class AtariEnvFactory(EnvFactoryGymnasium):
-    def __init__(self, task: str, seed: int, frame_stack: int, scale: bool = False):
+    def __init__(
+        self,
+        task: str,
+        seed: int,
+        frame_stack: int,
+        scale: bool = False,
+        use_envpool_if_available: bool = True,
+    ):
         assert "NoFrameskip" in task
         self.frame_stack = frame_stack
         self.scale = scale
+        envpool_factory = None
+        if use_envpool_if_available:
+            if envpool_is_available:
+                envpool_factory = self.EnvPoolFactory(self)
+                log.info("Using envpool, because it is available")
+            else:
+                log.info("Not using envpool, because it is not available")
         super().__init__(
             task=task,
             seed=seed,
             venv_type=VectorEnvType.SUBPROC_SHARED_MEM,
-            envpool_factory=self.EnvPoolFactory(self),
+            envpool_factory=envpool_factory,
         )
 
     def create_env(self, mode: EnvMode) -> Env:
diff --git a/examples/mujoco/mujoco_env.py b/examples/mujoco/mujoco_env.py
index 68a0a12..dd752e9 100644
--- a/examples/mujoco/mujoco_env.py
+++ b/examples/mujoco/mujoco_env.py
@@ -11,9 +11,11 @@ from tianshou.highlevel.env import (
 from tianshou.highlevel.persistence import Persistence, PersistEvent, RestoreEvent
 from tianshou.highlevel.world import World
 
+envpool_is_available = True
 try:
     import envpool
 except ImportError:
+    envpool_is_available = False
     envpool = None
 
 log = logging.getLogger(__name__)
@@ -62,7 +64,7 @@ class MujocoEnvFactory(EnvFactoryGymnasium):
             task=task,
             seed=seed,
             venv_type=VectorEnvType.SUBPROC_SHARED_MEM,
-            envpool_factory=EnvPoolFactory(),
+            envpool_factory=EnvPoolFactory() if envpool_is_available else None,
         )
         self.obs_norm = obs_norm
 
diff --git a/tianshou/highlevel/env.py b/tianshou/highlevel/env.py
index 915f9ad..5298c2d 100644
--- a/tianshou/highlevel/env.py
+++ b/tianshou/highlevel/env.py
@@ -295,19 +295,16 @@ class EnvPoolFactory:
         seed: int,
         kwargs: dict,
     ) -> BaseVectorEnv | None:
-        try:
-            import envpool
+        import envpool
 
-            envpool_task = self._transform_task(task)
-            envpool_kwargs = self._transform_kwargs(kwargs, mode)
-            return envpool.make_gymnasium(
-                envpool_task,
-                num_envs=num_envs,
-                seed=seed,
-                **envpool_kwargs,
-            )
-        except ImportError:
-            return None
+        envpool_task = self._transform_task(task)
+        envpool_kwargs = self._transform_kwargs(kwargs, mode)
+        return envpool.make_gymnasium(
+            envpool_task,
+            num_envs=num_envs,
+            seed=seed,
+            **envpool_kwargs,
+        )
 
 
 class EnvFactory(ToStringMixin, ABC):
@@ -364,9 +361,8 @@ class EnvFactoryGymnasium(EnvFactory):
     ):
         """:param task: the gymnasium task/environment identifier
         :param seed: the random seed
-        :param venv_type: the type of vectorized environment to use. If `envpool_factory` is specified, this is but a fallback.
-        :param envpool_factory: the factory to use for envpool-based vectorized environment creation if `envpool` is installed.
-            If it is not installed, `venv_type` applies as a fallback.
+        :param venv_type: the type of vectorized environment to use (if `envpool_factory` is not specified)
+        :param envpool_factory: the factory to use for vectorized environment creation based on envpool; envpool must be installed.
         :param render_mode_train: the render mode to use for training environments
         :param render_mode_test: the render mode to use for test environments
         :param render_mode_watch: the render mode to use for environments that are used to watch agent performance
@@ -406,19 +402,14 @@ class EnvFactoryGymnasium(EnvFactory):
 
     def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv:
         if self.envpool_factory is not None:
-            venv = self.envpool_factory.create_venv(
+            return self.envpool_factory.create_venv(
                 self.task,
                 num_envs,
                 mode,
                 self.seed,
                 self._create_kwargs(mode),
             )
-            if venv is not None:
-                return venv
-            log.debug(
-                f"EnvPool-based creation could not be applied, falling back to default based on {self.venv_type}",
-            )
-
-        venv = super().create_venv(num_envs, mode)
-        venv.seed(self.seed)
-        return venv
+        else:
+            venv = super().create_venv(num_envs, mode)
+            venv.seed(self.seed)
+            return venv