Refactoring, improving class name EnvFactoryGymnasium -> EnvFactoryRegistered
This commit is contained in:
parent
c9cb41bf55
commit
05a8cf4e74
@ -208,7 +208,7 @@ To get started, we need some imports.
|
||||
```python
|
||||
from tianshou.highlevel.config import SamplingConfig
|
||||
from tianshou.highlevel.env import (
|
||||
EnvFactoryGymnasium,
|
||||
EnvFactoryRegistered,
|
||||
VectorEnvType,
|
||||
)
|
||||
from tianshou.highlevel.experiment import DQNExperimentBuilder, ExperimentConfig
|
||||
|
@ -10,7 +10,7 @@ import numpy as np
|
||||
from gymnasium import Env
|
||||
|
||||
from tianshou.highlevel.env import (
|
||||
EnvFactoryGymnasium,
|
||||
EnvFactoryRegistered,
|
||||
EnvMode,
|
||||
EnvPoolFactory,
|
||||
VectorEnvType,
|
||||
@ -345,7 +345,7 @@ def make_atari_env(
|
||||
return envs.env, envs.train_envs, envs.test_envs
|
||||
|
||||
|
||||
class AtariEnvFactory(EnvFactoryGymnasium):
|
||||
class AtariEnvFactory(EnvFactoryRegistered):
|
||||
def __init__(
|
||||
self,
|
||||
task: str,
|
||||
|
@ -1,6 +1,6 @@
|
||||
from tianshou.highlevel.config import SamplingConfig
|
||||
from tianshou.highlevel.env import (
|
||||
EnvFactoryGymnasium,
|
||||
EnvFactoryRegistered,
|
||||
VectorEnvType,
|
||||
)
|
||||
from tianshou.highlevel.experiment import DQNExperimentBuilder, ExperimentConfig
|
||||
@ -16,7 +16,7 @@ from tianshou.utils.logging import run_main
|
||||
def main():
|
||||
experiment = (
|
||||
DQNExperimentBuilder(
|
||||
EnvFactoryGymnasium(task="CartPole-v1", seed=0, venv_type=VectorEnvType.DUMMY),
|
||||
EnvFactoryRegistered(task="CartPole-v1", seed=0, venv_type=VectorEnvType.DUMMY),
|
||||
ExperimentConfig(
|
||||
persistence_enabled=False,
|
||||
watch=True,
|
||||
|
@ -4,7 +4,7 @@ import pickle
|
||||
from tianshou.env import VectorEnvNormObs
|
||||
from tianshou.highlevel.env import (
|
||||
ContinuousEnvironments,
|
||||
EnvFactoryGymnasium,
|
||||
EnvFactoryRegistered,
|
||||
EnvPoolFactory,
|
||||
VectorEnvType,
|
||||
)
|
||||
@ -58,7 +58,7 @@ class MujocoEnvObsRmsPersistence(Persistence):
|
||||
world.envs.test_envs.set_obs_rms(obs_rms)
|
||||
|
||||
|
||||
class MujocoEnvFactory(EnvFactoryGymnasium):
|
||||
class MujocoEnvFactory(EnvFactoryRegistered):
|
||||
def __init__(self, task: str, seed: int, obs_norm=True):
|
||||
super().__init__(
|
||||
task=task,
|
||||
|
@ -1,14 +1,14 @@
|
||||
from tianshou.highlevel.env import (
|
||||
EnvFactoryGymnasium,
|
||||
EnvFactoryRegistered,
|
||||
VectorEnvType,
|
||||
)
|
||||
|
||||
|
||||
class DiscreteTestEnvFactory(EnvFactoryGymnasium):
|
||||
class DiscreteTestEnvFactory(EnvFactoryRegistered):
|
||||
def __init__(self):
|
||||
super().__init__(task="CartPole-v0", seed=42, venv_type=VectorEnvType.DUMMY)
|
||||
|
||||
|
||||
class ContinuousTestEnvFactory(EnvFactoryGymnasium):
|
||||
class ContinuousTestEnvFactory(EnvFactoryRegistered):
|
||||
def __init__(self):
|
||||
super().__init__(task="Pendulum-v1", seed=42, venv_type=VectorEnvType.DUMMY)
|
||||
|
@ -344,8 +344,9 @@ class EnvFactory(ToStringMixin, ABC):
|
||||
raise ValueError
|
||||
|
||||
|
||||
class EnvFactoryGymnasium(EnvFactory):
|
||||
"""Factory for environments that can be created via `gymnasium.make` (or via `envpool.make_gymnasium`)."""
|
||||
class EnvFactoryRegistered(EnvFactory):
|
||||
"""Factory for environments that are registered with gymnasium and thus can be created via `gymnasium.make`
|
||||
(or via `envpool.make_gymnasium`)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -357,7 +358,7 @@ class EnvFactoryGymnasium(EnvFactory):
|
||||
render_mode_train: str | None = None,
|
||||
render_mode_test: str | None = None,
|
||||
render_mode_watch: str = "human",
|
||||
**kwargs: Any,
|
||||
**make_kwargs: Any,
|
||||
):
|
||||
""":param task: the gymnasium task/environment identifier
|
||||
:param seed: the random seed
|
||||
@ -366,7 +367,7 @@ class EnvFactoryGymnasium(EnvFactory):
|
||||
:param render_mode_train: the render mode to use for training environments
|
||||
:param render_mode_test: the render mode to use for test environments
|
||||
:param render_mode_watch: the render mode to use for environments that are used to watch agent performance
|
||||
:param kwargs: additional keyword arguments to pass on to `gymnasium.make`.
|
||||
:param make_kwargs: additional keyword arguments to pass on to `gymnasium.make`.
|
||||
If envpool is used, the gymnasium parameters will be appropriately translated for use with
|
||||
`envpool.make_gymnasium`.
|
||||
"""
|
||||
@ -379,7 +380,7 @@ class EnvFactoryGymnasium(EnvFactory):
|
||||
EnvMode.TEST: render_mode_test,
|
||||
EnvMode.WATCH: render_mode_watch,
|
||||
}
|
||||
self.kwargs = kwargs
|
||||
self.make_kwargs = make_kwargs
|
||||
|
||||
def _create_kwargs(self, mode: EnvMode) -> dict:
|
||||
"""Adapts the keyword arguments for the given mode.
|
||||
@ -387,7 +388,7 @@ class EnvFactoryGymnasium(EnvFactory):
|
||||
:param mode: the mode
|
||||
:return: adapted keyword arguments
|
||||
"""
|
||||
kwargs = dict(self.kwargs)
|
||||
kwargs = dict(self.make_kwargs)
|
||||
kwargs["render_mode"] = self.render_modes.get(mode)
|
||||
return kwargs
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user