Refactoring, improving class name EnvFactoryGymnasium -> EnvFactoryRegistered

This commit is contained in:
Dominik Jain 2024-01-16 12:22:07 +01:00
parent c9cb41bf55
commit 05a8cf4e74
6 changed files with 17 additions and 16 deletions

View File

@ -208,7 +208,7 @@ To get started, we need some imports.
```python
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import (
EnvFactoryGymnasium,
EnvFactoryRegistered,
VectorEnvType,
)
from tianshou.highlevel.experiment import DQNExperimentBuilder, ExperimentConfig

View File

@ -10,7 +10,7 @@ import numpy as np
from gymnasium import Env
from tianshou.highlevel.env import (
EnvFactoryGymnasium,
EnvFactoryRegistered,
EnvMode,
EnvPoolFactory,
VectorEnvType,
@ -345,7 +345,7 @@ def make_atari_env(
return envs.env, envs.train_envs, envs.test_envs
class AtariEnvFactory(EnvFactoryGymnasium):
class AtariEnvFactory(EnvFactoryRegistered):
def __init__(
self,
task: str,

View File

@ -1,6 +1,6 @@
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import (
EnvFactoryGymnasium,
EnvFactoryRegistered,
VectorEnvType,
)
from tianshou.highlevel.experiment import DQNExperimentBuilder, ExperimentConfig
@ -16,7 +16,7 @@ from tianshou.utils.logging import run_main
def main():
experiment = (
DQNExperimentBuilder(
EnvFactoryGymnasium(task="CartPole-v1", seed=0, venv_type=VectorEnvType.DUMMY),
EnvFactoryRegistered(task="CartPole-v1", seed=0, venv_type=VectorEnvType.DUMMY),
ExperimentConfig(
persistence_enabled=False,
watch=True,

View File

@ -4,7 +4,7 @@ import pickle
from tianshou.env import VectorEnvNormObs
from tianshou.highlevel.env import (
ContinuousEnvironments,
EnvFactoryGymnasium,
EnvFactoryRegistered,
EnvPoolFactory,
VectorEnvType,
)
@ -58,7 +58,7 @@ class MujocoEnvObsRmsPersistence(Persistence):
world.envs.test_envs.set_obs_rms(obs_rms)
class MujocoEnvFactory(EnvFactoryGymnasium):
class MujocoEnvFactory(EnvFactoryRegistered):
def __init__(self, task: str, seed: int, obs_norm=True):
super().__init__(
task=task,

View File

@ -1,14 +1,14 @@
from tianshou.highlevel.env import (
EnvFactoryGymnasium,
EnvFactoryRegistered,
VectorEnvType,
)
class DiscreteTestEnvFactory(EnvFactoryGymnasium):
class DiscreteTestEnvFactory(EnvFactoryRegistered):
def __init__(self):
super().__init__(task="CartPole-v0", seed=42, venv_type=VectorEnvType.DUMMY)
class ContinuousTestEnvFactory(EnvFactoryGymnasium):
class ContinuousTestEnvFactory(EnvFactoryRegistered):
def __init__(self):
super().__init__(task="Pendulum-v1", seed=42, venv_type=VectorEnvType.DUMMY)

View File

@ -344,8 +344,9 @@ class EnvFactory(ToStringMixin, ABC):
raise ValueError
class EnvFactoryGymnasium(EnvFactory):
"""Factory for environments that can be created via `gymnasium.make` (or via `envpool.make_gymnasium`)."""
class EnvFactoryRegistered(EnvFactory):
"""Factory for environments that are registered with gymnasium and thus can be created via `gymnasium.make`
(or via `envpool.make_gymnasium`)."""
def __init__(
self,
@ -357,7 +358,7 @@ class EnvFactoryGymnasium(EnvFactory):
render_mode_train: str | None = None,
render_mode_test: str | None = None,
render_mode_watch: str = "human",
**kwargs: Any,
**make_kwargs: Any,
):
""":param task: the gymnasium task/environment identifier
:param seed: the random seed
@ -366,7 +367,7 @@ class EnvFactoryGymnasium(EnvFactory):
:param render_mode_train: the render mode to use for training environments
:param render_mode_test: the render mode to use for test environments
:param render_mode_watch: the render mode to use for environments that are used to watch agent performance
:param kwargs: additional keyword arguments to pass on to `gymnasium.make`.
:param make_kwargs: additional keyword arguments to pass on to `gymnasium.make`.
If envpool is used, the gymnasium parameters will be appropriately translated for use with
`envpool.make_gymnasium`.
"""
@ -379,7 +380,7 @@ class EnvFactoryGymnasium(EnvFactory):
EnvMode.TEST: render_mode_test,
EnvMode.WATCH: render_mode_watch,
}
self.kwargs = kwargs
self.make_kwargs = make_kwargs
def _create_kwargs(self, mode: EnvMode) -> dict:
"""Adapts the keyword arguments for the given mode.
@ -387,7 +388,7 @@ class EnvFactoryGymnasium(EnvFactory):
:param mode: the mode
:return: adapted keyword arguments
"""
kwargs = dict(self.kwargs)
kwargs = dict(self.make_kwargs)
kwargs["render_mode"] = self.render_modes.get(mode)
return kwargs