Refactoring, improving class name EnvFactoryGymnasium -> EnvFactoryRegistered
This commit is contained in:
parent
c9cb41bf55
commit
05a8cf4e74
@ -208,7 +208,7 @@ To get started, we need some imports.
|
|||||||
```python
|
```python
|
||||||
from tianshou.highlevel.config import SamplingConfig
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
from tianshou.highlevel.env import (
|
from tianshou.highlevel.env import (
|
||||||
EnvFactoryGymnasium,
|
EnvFactoryRegistered,
|
||||||
VectorEnvType,
|
VectorEnvType,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.experiment import DQNExperimentBuilder, ExperimentConfig
|
from tianshou.highlevel.experiment import DQNExperimentBuilder, ExperimentConfig
|
||||||
|
@ -10,7 +10,7 @@ import numpy as np
|
|||||||
from gymnasium import Env
|
from gymnasium import Env
|
||||||
|
|
||||||
from tianshou.highlevel.env import (
|
from tianshou.highlevel.env import (
|
||||||
EnvFactoryGymnasium,
|
EnvFactoryRegistered,
|
||||||
EnvMode,
|
EnvMode,
|
||||||
EnvPoolFactory,
|
EnvPoolFactory,
|
||||||
VectorEnvType,
|
VectorEnvType,
|
||||||
@ -345,7 +345,7 @@ def make_atari_env(
|
|||||||
return envs.env, envs.train_envs, envs.test_envs
|
return envs.env, envs.train_envs, envs.test_envs
|
||||||
|
|
||||||
|
|
||||||
class AtariEnvFactory(EnvFactoryGymnasium):
|
class AtariEnvFactory(EnvFactoryRegistered):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
task: str,
|
task: str,
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from tianshou.highlevel.config import SamplingConfig
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
from tianshou.highlevel.env import (
|
from tianshou.highlevel.env import (
|
||||||
EnvFactoryGymnasium,
|
EnvFactoryRegistered,
|
||||||
VectorEnvType,
|
VectorEnvType,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.experiment import DQNExperimentBuilder, ExperimentConfig
|
from tianshou.highlevel.experiment import DQNExperimentBuilder, ExperimentConfig
|
||||||
@ -16,7 +16,7 @@ from tianshou.utils.logging import run_main
|
|||||||
def main():
|
def main():
|
||||||
experiment = (
|
experiment = (
|
||||||
DQNExperimentBuilder(
|
DQNExperimentBuilder(
|
||||||
EnvFactoryGymnasium(task="CartPole-v1", seed=0, venv_type=VectorEnvType.DUMMY),
|
EnvFactoryRegistered(task="CartPole-v1", seed=0, venv_type=VectorEnvType.DUMMY),
|
||||||
ExperimentConfig(
|
ExperimentConfig(
|
||||||
persistence_enabled=False,
|
persistence_enabled=False,
|
||||||
watch=True,
|
watch=True,
|
||||||
|
@ -4,7 +4,7 @@ import pickle
|
|||||||
from tianshou.env import VectorEnvNormObs
|
from tianshou.env import VectorEnvNormObs
|
||||||
from tianshou.highlevel.env import (
|
from tianshou.highlevel.env import (
|
||||||
ContinuousEnvironments,
|
ContinuousEnvironments,
|
||||||
EnvFactoryGymnasium,
|
EnvFactoryRegistered,
|
||||||
EnvPoolFactory,
|
EnvPoolFactory,
|
||||||
VectorEnvType,
|
VectorEnvType,
|
||||||
)
|
)
|
||||||
@ -58,7 +58,7 @@ class MujocoEnvObsRmsPersistence(Persistence):
|
|||||||
world.envs.test_envs.set_obs_rms(obs_rms)
|
world.envs.test_envs.set_obs_rms(obs_rms)
|
||||||
|
|
||||||
|
|
||||||
class MujocoEnvFactory(EnvFactoryGymnasium):
|
class MujocoEnvFactory(EnvFactoryRegistered):
|
||||||
def __init__(self, task: str, seed: int, obs_norm=True):
|
def __init__(self, task: str, seed: int, obs_norm=True):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
task=task,
|
task=task,
|
||||||
|
@ -1,14 +1,14 @@
|
|||||||
from tianshou.highlevel.env import (
|
from tianshou.highlevel.env import (
|
||||||
EnvFactoryGymnasium,
|
EnvFactoryRegistered,
|
||||||
VectorEnvType,
|
VectorEnvType,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class DiscreteTestEnvFactory(EnvFactoryGymnasium):
|
class DiscreteTestEnvFactory(EnvFactoryRegistered):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(task="CartPole-v0", seed=42, venv_type=VectorEnvType.DUMMY)
|
super().__init__(task="CartPole-v0", seed=42, venv_type=VectorEnvType.DUMMY)
|
||||||
|
|
||||||
|
|
||||||
class ContinuousTestEnvFactory(EnvFactoryGymnasium):
|
class ContinuousTestEnvFactory(EnvFactoryRegistered):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(task="Pendulum-v1", seed=42, venv_type=VectorEnvType.DUMMY)
|
super().__init__(task="Pendulum-v1", seed=42, venv_type=VectorEnvType.DUMMY)
|
||||||
|
@ -344,8 +344,9 @@ class EnvFactory(ToStringMixin, ABC):
|
|||||||
raise ValueError
|
raise ValueError
|
||||||
|
|
||||||
|
|
||||||
class EnvFactoryGymnasium(EnvFactory):
|
class EnvFactoryRegistered(EnvFactory):
|
||||||
"""Factory for environments that can be created via `gymnasium.make` (or via `envpool.make_gymnasium`)."""
|
"""Factory for environments that are registered with gymnasium and thus can be created via `gymnasium.make`
|
||||||
|
(or via `envpool.make_gymnasium`)."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -357,7 +358,7 @@ class EnvFactoryGymnasium(EnvFactory):
|
|||||||
render_mode_train: str | None = None,
|
render_mode_train: str | None = None,
|
||||||
render_mode_test: str | None = None,
|
render_mode_test: str | None = None,
|
||||||
render_mode_watch: str = "human",
|
render_mode_watch: str = "human",
|
||||||
**kwargs: Any,
|
**make_kwargs: Any,
|
||||||
):
|
):
|
||||||
""":param task: the gymnasium task/environment identifier
|
""":param task: the gymnasium task/environment identifier
|
||||||
:param seed: the random seed
|
:param seed: the random seed
|
||||||
@ -366,7 +367,7 @@ class EnvFactoryGymnasium(EnvFactory):
|
|||||||
:param render_mode_train: the render mode to use for training environments
|
:param render_mode_train: the render mode to use for training environments
|
||||||
:param render_mode_test: the render mode to use for test environments
|
:param render_mode_test: the render mode to use for test environments
|
||||||
:param render_mode_watch: the render mode to use for environments that are used to watch agent performance
|
:param render_mode_watch: the render mode to use for environments that are used to watch agent performance
|
||||||
:param kwargs: additional keyword arguments to pass on to `gymnasium.make`.
|
:param make_kwargs: additional keyword arguments to pass on to `gymnasium.make`.
|
||||||
If envpool is used, the gymnasium parameters will be appropriately translated for use with
|
If envpool is used, the gymnasium parameters will be appropriately translated for use with
|
||||||
`envpool.make_gymnasium`.
|
`envpool.make_gymnasium`.
|
||||||
"""
|
"""
|
||||||
@ -379,7 +380,7 @@ class EnvFactoryGymnasium(EnvFactory):
|
|||||||
EnvMode.TEST: render_mode_test,
|
EnvMode.TEST: render_mode_test,
|
||||||
EnvMode.WATCH: render_mode_watch,
|
EnvMode.WATCH: render_mode_watch,
|
||||||
}
|
}
|
||||||
self.kwargs = kwargs
|
self.make_kwargs = make_kwargs
|
||||||
|
|
||||||
def _create_kwargs(self, mode: EnvMode) -> dict:
|
def _create_kwargs(self, mode: EnvMode) -> dict:
|
||||||
"""Adapts the keyword arguments for the given mode.
|
"""Adapts the keyword arguments for the given mode.
|
||||||
@ -387,7 +388,7 @@ class EnvFactoryGymnasium(EnvFactory):
|
|||||||
:param mode: the mode
|
:param mode: the mode
|
||||||
:return: adapted keyword arguments
|
:return: adapted keyword arguments
|
||||||
"""
|
"""
|
||||||
kwargs = dict(self.kwargs)
|
kwargs = dict(self.make_kwargs)
|
||||||
kwargs["render_mode"] = self.render_modes.get(mode)
|
kwargs["render_mode"] = self.render_modes.get(mode)
|
||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user