diff --git a/README.md b/README.md index d309f2b..ddbb6d9 100644 --- a/README.md +++ b/README.md @@ -208,7 +208,7 @@ To get started, we need some imports. ```python from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.env import ( - EnvFactoryGymnasium, + EnvFactoryRegistered, VectorEnvType, ) from tianshou.highlevel.experiment import DQNExperimentBuilder, ExperimentConfig diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py index 6e6c4e5..e8023b7 100644 --- a/examples/atari/atari_wrapper.py +++ b/examples/atari/atari_wrapper.py @@ -10,7 +10,7 @@ import numpy as np from gymnasium import Env from tianshou.highlevel.env import ( - EnvFactoryGymnasium, + EnvFactoryRegistered, EnvMode, EnvPoolFactory, VectorEnvType, @@ -345,7 +345,7 @@ def make_atari_env( return envs.env, envs.train_envs, envs.test_envs -class AtariEnvFactory(EnvFactoryGymnasium): +class AtariEnvFactory(EnvFactoryRegistered): def __init__( self, task: str, diff --git a/examples/discrete/discrete_dqn_hl.py b/examples/discrete/discrete_dqn_hl.py index c6092f2..e0f4ca5 100644 --- a/examples/discrete/discrete_dqn_hl.py +++ b/examples/discrete/discrete_dqn_hl.py @@ -1,6 +1,6 @@ from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.env import ( - EnvFactoryGymnasium, + EnvFactoryRegistered, VectorEnvType, ) from tianshou.highlevel.experiment import DQNExperimentBuilder, ExperimentConfig @@ -16,7 +16,7 @@ from tianshou.utils.logging import run_main def main(): experiment = ( DQNExperimentBuilder( - EnvFactoryGymnasium(task="CartPole-v1", seed=0, venv_type=VectorEnvType.DUMMY), + EnvFactoryRegistered(task="CartPole-v1", seed=0, venv_type=VectorEnvType.DUMMY), ExperimentConfig( persistence_enabled=False, watch=True, diff --git a/examples/mujoco/mujoco_env.py b/examples/mujoco/mujoco_env.py index dd752e9..081e41e 100644 --- a/examples/mujoco/mujoco_env.py +++ b/examples/mujoco/mujoco_env.py @@ -4,7 +4,7 @@ import pickle from tianshou.env import VectorEnvNormObs from tianshou.highlevel.env import ( ContinuousEnvironments, - EnvFactoryGymnasium, + EnvFactoryRegistered, EnvPoolFactory, VectorEnvType, ) @@ -58,7 +58,7 @@ class MujocoEnvObsRmsPersistence(Persistence): world.envs.test_envs.set_obs_rms(obs_rms) -class MujocoEnvFactory(EnvFactoryGymnasium): +class MujocoEnvFactory(EnvFactoryRegistered): def __init__(self, task: str, seed: int, obs_norm=True): super().__init__( task=task, diff --git a/test/highlevel/env_factory.py b/test/highlevel/env_factory.py index 760bb5b..4d09ce8 100644 --- a/test/highlevel/env_factory.py +++ b/test/highlevel/env_factory.py @@ -1,14 +1,14 @@ from tianshou.highlevel.env import ( - EnvFactoryGymnasium, + EnvFactoryRegistered, VectorEnvType, ) -class DiscreteTestEnvFactory(EnvFactoryGymnasium): +class DiscreteTestEnvFactory(EnvFactoryRegistered): def __init__(self): super().__init__(task="CartPole-v0", seed=42, venv_type=VectorEnvType.DUMMY) -class ContinuousTestEnvFactory(EnvFactoryGymnasium): +class ContinuousTestEnvFactory(EnvFactoryRegistered): def __init__(self): super().__init__(task="Pendulum-v1", seed=42, venv_type=VectorEnvType.DUMMY) diff --git a/tianshou/highlevel/env.py b/tianshou/highlevel/env.py index 5298c2d..0c07035 100644 --- a/tianshou/highlevel/env.py +++ b/tianshou/highlevel/env.py @@ -344,8 +344,9 @@ class EnvFactory(ToStringMixin, ABC): raise ValueError -class EnvFactoryGymnasium(EnvFactory): - """Factory for environments that can be created via `gymnasium.make` (or via `envpool.make_gymnasium`).""" +class EnvFactoryRegistered(EnvFactory): + """Factory for environments that are registered with gymnasium and thus can be created via `gymnasium.make` + (or via `envpool.make_gymnasium`).""" def __init__( self, @@ -357,7 +358,7 @@ class EnvFactoryGymnasium(EnvFactory): render_mode_train: str | None = None, render_mode_test: str | None = None, render_mode_watch: str = "human", - **kwargs: Any, + **make_kwargs: Any, ): """:param task: the gymnasium task/environment identifier :param seed: the random seed @@ -366,7 +367,7 @@ class EnvFactoryGymnasium(EnvFactory): :param render_mode_train: the render mode to use for training environments :param render_mode_test: the render mode to use for test environments :param render_mode_watch: the render mode to use for environments that are used to watch agent performance - :param kwargs: additional keyword arguments to pass on to `gymnasium.make`. + :param make_kwargs: additional keyword arguments to pass on to `gymnasium.make`. If envpool is used, the gymnasium parameters will be appropriately translated for use with `envpool.make_gymnasium`. """ @@ -379,7 +380,7 @@ class EnvFactoryGymnasium(EnvFactory): EnvMode.TEST: render_mode_test, EnvMode.WATCH: render_mode_watch, } - self.kwargs = kwargs + self.make_kwargs = make_kwargs def _create_kwargs(self, mode: EnvMode) -> dict: """Adapts the keyword arguments for the given mode. @@ -387,7 +388,7 @@ class EnvFactoryGymnasium(EnvFactory): :param mode: the mode :return: adapted keyword arguments """ - kwargs = dict(self.kwargs) + kwargs = dict(self.make_kwargs) kwargs["render_mode"] = self.render_modes.get(mode) return kwargs