diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py index e8023b7..2f098b3 100644 --- a/examples/atari/atari_wrapper.py +++ b/examples/atari/atari_wrapper.py @@ -383,6 +383,11 @@ class AtariEnvFactory(EnvFactoryRegistered): ) class EnvPoolFactory(EnvPoolFactory): + """Atari-specific envpool creation. + Since envpool internally handles the functions that are implemented through the wrappers in `wrap_deepmind`, + it sets the creation keyword arguments accordingly. + """ + def __init__(self, parent: "AtariEnvFactory"): self.parent = parent if self.parent.scale: @@ -393,6 +398,7 @@ class AtariEnvFactory(EnvFactoryRegistered): def _transform_task(self, task: str) -> str: task = super()._transform_task(task) + # TODO: Maybe warn user, explain why this is needed return task.replace("NoFrameskip-v4", "-v5") def _transform_kwargs(self, kwargs: dict, mode: EnvMode) -> dict: diff --git a/tianshou/highlevel/env.py b/tianshou/highlevel/env.py index 0c07035..5327794 100644 --- a/tianshou/highlevel/env.py +++ b/tianshou/highlevel/env.py @@ -270,7 +270,9 @@ class DiscreteEnvironments(Environments): class EnvPoolFactory: - """A factory for the creation of envpool-based vectorized environments.""" + """A factory for the creation of envpool-based vectorized environments, which can be used in conjunction + with :class:`EnvFactoryRegistered`. + """ def _transform_task(self, task: str) -> str: return task @@ -294,7 +296,7 @@ class EnvPoolFactory: mode: EnvMode, seed: int, kwargs: dict, - ) -> BaseVectorEnv | None: + ) -> BaseVectorEnv: import envpool envpool_task = self._transform_task(task) @@ -308,6 +310,8 @@ class EnvPoolFactory: class EnvFactory(ToStringMixin, ABC): + """Main interface for the creation of environments (in various forms).""" + def __init__(self, venv_type: VectorEnvType): """:param venv_type: the type of vectorized environment to use""" self.venv_type = venv_type @@ -346,7 +350,8 @@ class EnvFactory(ToStringMixin, ABC): class EnvFactoryRegistered(EnvFactory): """Factory for environments that are registered with gymnasium and thus can be created via `gymnasium.make` - (or via `envpool.make_gymnasium`).""" + (or via `envpool.make_gymnasium`). + """ def __init__( self,