Make envpool usage configuration more explicit
This commit is contained in:
parent
a4d7ccba26
commit
c9cb41bf55
@ -1,6 +1,6 @@
|
|||||||
# Borrow a lot from openai baselines:
|
# Borrow a lot from openai baselines:
|
||||||
# https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
|
# https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
|
||||||
|
import logging
|
||||||
import warnings
|
import warnings
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
|
||||||
@ -17,10 +17,13 @@ from tianshou.highlevel.env import (
|
|||||||
)
|
)
|
||||||
from tianshou.highlevel.trainer import EpochStopCallback, TrainingContext
|
from tianshou.highlevel.trainer import EpochStopCallback, TrainingContext
|
||||||
|
|
||||||
|
envpool_is_available = True
|
||||||
try:
|
try:
|
||||||
import envpool
|
import envpool
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
envpool_is_available = False
|
||||||
envpool = None
|
envpool = None
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _parse_reset_result(reset_result):
|
def _parse_reset_result(reset_result):
|
||||||
@ -343,15 +346,29 @@ def make_atari_env(
|
|||||||
|
|
||||||
|
|
||||||
class AtariEnvFactory(EnvFactoryGymnasium):
|
class AtariEnvFactory(EnvFactoryGymnasium):
|
||||||
def __init__(self, task: str, seed: int, frame_stack: int, scale: bool = False):
|
def __init__(
|
||||||
|
self,
|
||||||
|
task: str,
|
||||||
|
seed: int,
|
||||||
|
frame_stack: int,
|
||||||
|
scale: bool = False,
|
||||||
|
use_envpool_if_available: bool = True,
|
||||||
|
):
|
||||||
assert "NoFrameskip" in task
|
assert "NoFrameskip" in task
|
||||||
self.frame_stack = frame_stack
|
self.frame_stack = frame_stack
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
|
envpool_factory = None
|
||||||
|
if use_envpool_if_available:
|
||||||
|
if envpool_is_available:
|
||||||
|
envpool_factory = self.EnvPoolFactory(self)
|
||||||
|
log.info("Using envpool, because it is available")
|
||||||
|
else:
|
||||||
|
log.info("Not using envpool, because it is not available")
|
||||||
super().__init__(
|
super().__init__(
|
||||||
task=task,
|
task=task,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
venv_type=VectorEnvType.SUBPROC_SHARED_MEM,
|
venv_type=VectorEnvType.SUBPROC_SHARED_MEM,
|
||||||
envpool_factory=self.EnvPoolFactory(self),
|
envpool_factory=envpool_factory,
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_env(self, mode: EnvMode) -> Env:
|
def create_env(self, mode: EnvMode) -> Env:
|
||||||
|
@ -11,9 +11,11 @@ from tianshou.highlevel.env import (
|
|||||||
from tianshou.highlevel.persistence import Persistence, PersistEvent, RestoreEvent
|
from tianshou.highlevel.persistence import Persistence, PersistEvent, RestoreEvent
|
||||||
from tianshou.highlevel.world import World
|
from tianshou.highlevel.world import World
|
||||||
|
|
||||||
|
envpool_is_available = True
|
||||||
try:
|
try:
|
||||||
import envpool
|
import envpool
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
envpool_is_available = False
|
||||||
envpool = None
|
envpool = None
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
@ -62,7 +64,7 @@ class MujocoEnvFactory(EnvFactoryGymnasium):
|
|||||||
task=task,
|
task=task,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
venv_type=VectorEnvType.SUBPROC_SHARED_MEM,
|
venv_type=VectorEnvType.SUBPROC_SHARED_MEM,
|
||||||
envpool_factory=EnvPoolFactory(),
|
envpool_factory=EnvPoolFactory() if envpool_is_available else None,
|
||||||
)
|
)
|
||||||
self.obs_norm = obs_norm
|
self.obs_norm = obs_norm
|
||||||
|
|
||||||
|
@ -295,19 +295,16 @@ class EnvPoolFactory:
|
|||||||
seed: int,
|
seed: int,
|
||||||
kwargs: dict,
|
kwargs: dict,
|
||||||
) -> BaseVectorEnv | None:
|
) -> BaseVectorEnv | None:
|
||||||
try:
|
import envpool
|
||||||
import envpool
|
|
||||||
|
|
||||||
envpool_task = self._transform_task(task)
|
envpool_task = self._transform_task(task)
|
||||||
envpool_kwargs = self._transform_kwargs(kwargs, mode)
|
envpool_kwargs = self._transform_kwargs(kwargs, mode)
|
||||||
return envpool.make_gymnasium(
|
return envpool.make_gymnasium(
|
||||||
envpool_task,
|
envpool_task,
|
||||||
num_envs=num_envs,
|
num_envs=num_envs,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
**envpool_kwargs,
|
**envpool_kwargs,
|
||||||
)
|
)
|
||||||
except ImportError:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class EnvFactory(ToStringMixin, ABC):
|
class EnvFactory(ToStringMixin, ABC):
|
||||||
@ -364,9 +361,8 @@ class EnvFactoryGymnasium(EnvFactory):
|
|||||||
):
|
):
|
||||||
""":param task: the gymnasium task/environment identifier
|
""":param task: the gymnasium task/environment identifier
|
||||||
:param seed: the random seed
|
:param seed: the random seed
|
||||||
:param venv_type: the type of vectorized environment to use. If `envpool_factory` is specified, this is but a fallback.
|
:param venv_type: the type of vectorized environment to use (if `envpool_factory` is not specified)
|
||||||
:param envpool_factory: the factory to use for envpool-based vectorized environment creation if `envpool` is installed.
|
:param envpool_factory: the factory to use for vectorized environment creation based on envpool; envpool must be installed.
|
||||||
If it is not installed, `venv_type` applies as a fallback.
|
|
||||||
:param render_mode_train: the render mode to use for training environments
|
:param render_mode_train: the render mode to use for training environments
|
||||||
:param render_mode_test: the render mode to use for test environments
|
:param render_mode_test: the render mode to use for test environments
|
||||||
:param render_mode_watch: the render mode to use for environments that are used to watch agent performance
|
:param render_mode_watch: the render mode to use for environments that are used to watch agent performance
|
||||||
@ -406,19 +402,14 @@ class EnvFactoryGymnasium(EnvFactory):
|
|||||||
|
|
||||||
def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv:
|
def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv:
|
||||||
if self.envpool_factory is not None:
|
if self.envpool_factory is not None:
|
||||||
venv = self.envpool_factory.create_venv(
|
return self.envpool_factory.create_venv(
|
||||||
self.task,
|
self.task,
|
||||||
num_envs,
|
num_envs,
|
||||||
mode,
|
mode,
|
||||||
self.seed,
|
self.seed,
|
||||||
self._create_kwargs(mode),
|
self._create_kwargs(mode),
|
||||||
)
|
)
|
||||||
if venv is not None:
|
else:
|
||||||
return venv
|
venv = super().create_venv(num_envs, mode)
|
||||||
log.debug(
|
venv.seed(self.seed)
|
||||||
f"EnvPool-based creation could not be applied, falling back to default based on {self.venv_type}",
|
return venv
|
||||||
)
|
|
||||||
|
|
||||||
venv = super().create_venv(num_envs, mode)
|
|
||||||
venv.seed(self.seed)
|
|
||||||
return venv
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user