Make envpool usage configuration more explicit

This commit is contained in:
Dominik Jain 2024-01-16 12:16:46 +01:00
parent a4d7ccba26
commit c9cb41bf55
3 changed files with 39 additions and 29 deletions

View File

@ -1,6 +1,6 @@
# Borrow a lot from openai baselines: # Borrow a lot from openai baselines:
# https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py # https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
import logging
import warnings import warnings
from collections import deque from collections import deque
@ -17,10 +17,13 @@ from tianshou.highlevel.env import (
) )
from tianshou.highlevel.trainer import EpochStopCallback, TrainingContext from tianshou.highlevel.trainer import EpochStopCallback, TrainingContext
envpool_is_available = True
try: try:
import envpool import envpool
except ImportError: except ImportError:
envpool_is_available = False
envpool = None envpool = None
log = logging.getLogger(__name__)
def _parse_reset_result(reset_result): def _parse_reset_result(reset_result):
@ -343,15 +346,29 @@ def make_atari_env(
class AtariEnvFactory(EnvFactoryGymnasium): class AtariEnvFactory(EnvFactoryGymnasium):
def __init__(self, task: str, seed: int, frame_stack: int, scale: bool = False): def __init__(
self,
task: str,
seed: int,
frame_stack: int,
scale: bool = False,
use_envpool_if_available: bool = True,
):
assert "NoFrameskip" in task assert "NoFrameskip" in task
self.frame_stack = frame_stack self.frame_stack = frame_stack
self.scale = scale self.scale = scale
envpool_factory = None
if use_envpool_if_available:
if envpool_is_available:
envpool_factory = self.EnvPoolFactory(self)
log.info("Using envpool, because it available")
else:
log.info("Not using envpool, because it is not available")
super().__init__( super().__init__(
task=task, task=task,
seed=seed, seed=seed,
venv_type=VectorEnvType.SUBPROC_SHARED_MEM, venv_type=VectorEnvType.SUBPROC_SHARED_MEM,
envpool_factory=self.EnvPoolFactory(self), envpool_factory=envpool_factory,
) )
def create_env(self, mode: EnvMode) -> Env: def create_env(self, mode: EnvMode) -> Env:

View File

@ -11,9 +11,11 @@ from tianshou.highlevel.env import (
from tianshou.highlevel.persistence import Persistence, PersistEvent, RestoreEvent from tianshou.highlevel.persistence import Persistence, PersistEvent, RestoreEvent
from tianshou.highlevel.world import World from tianshou.highlevel.world import World
envpool_is_available = True
try: try:
import envpool import envpool
except ImportError: except ImportError:
envpool_is_available = False
envpool = None envpool = None
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -62,7 +64,7 @@ class MujocoEnvFactory(EnvFactoryGymnasium):
task=task, task=task,
seed=seed, seed=seed,
venv_type=VectorEnvType.SUBPROC_SHARED_MEM, venv_type=VectorEnvType.SUBPROC_SHARED_MEM,
envpool_factory=EnvPoolFactory(), envpool_factory=EnvPoolFactory() if envpool_is_available else None,
) )
self.obs_norm = obs_norm self.obs_norm = obs_norm

View File

@ -295,19 +295,16 @@ class EnvPoolFactory:
seed: int, seed: int,
kwargs: dict, kwargs: dict,
) -> BaseVectorEnv | None: ) -> BaseVectorEnv | None:
try: import envpool
import envpool
envpool_task = self._transform_task(task) envpool_task = self._transform_task(task)
envpool_kwargs = self._transform_kwargs(kwargs, mode) envpool_kwargs = self._transform_kwargs(kwargs, mode)
return envpool.make_gymnasium( return envpool.make_gymnasium(
envpool_task, envpool_task,
num_envs=num_envs, num_envs=num_envs,
seed=seed, seed=seed,
**envpool_kwargs, **envpool_kwargs,
) )
except ImportError:
return None
class EnvFactory(ToStringMixin, ABC): class EnvFactory(ToStringMixin, ABC):
@ -364,9 +361,8 @@ class EnvFactoryGymnasium(EnvFactory):
): ):
""":param task: the gymnasium task/environment identifier """:param task: the gymnasium task/environment identifier
:param seed: the random seed :param seed: the random seed
:param venv_type: the type of vectorized environment to use. If `envpool_factory` is specified, this is but a fallback. :param venv_type: the type of vectorized environment to use (if `envpool_factory` is not specified)
:param envpool_factory: the factory to use for envpool-based vectorized environment creation if `envpool` is installed. :param envpool_factory: the factory to use for vectorized environment creation based on envpool; envpool must be installed.
If it is not installed, `venv_type` applies as a fallback.
:param render_mode_train: the render mode to use for training environments :param render_mode_train: the render mode to use for training environments
:param render_mode_test: the render mode to use for test environments :param render_mode_test: the render mode to use for test environments
:param render_mode_watch: the render mode to use for environments that are used to watch agent performance :param render_mode_watch: the render mode to use for environments that are used to watch agent performance
@ -406,19 +402,14 @@ class EnvFactoryGymnasium(EnvFactory):
def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv: def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv:
if self.envpool_factory is not None: if self.envpool_factory is not None:
venv = self.envpool_factory.create_venv( return self.envpool_factory.create_venv(
self.task, self.task,
num_envs, num_envs,
mode, mode,
self.seed, self.seed,
self._create_kwargs(mode), self._create_kwargs(mode),
) )
if venv is not None: else:
return venv venv = super().create_venv(num_envs, mode)
log.debug( venv.seed(self.seed)
f"EnvPool-based creation could not be applied, falling back to default based on {self.venv_type}", return venv
)
venv = super().create_venv(num_envs, mode)
venv.seed(self.seed)
return venv