Improve environment factory abstractions in high-level API:
* EnvFactory now uses the creation of a single environment as the basic functionality which the more high-level functions build upon * Introduce enum EnvMode to indicate the purpose for which an env is created, allowing the factory creation process to change its behaviour accordingly * Add EnvFactoryGymnasium to provide direct support for envs that can be created via gymnasium.make - EnvPool is supported via an injectible EnvPoolFactory - Existing EnvFactory implementations are now derived from EnvFactoryGymnasium * Use a separate environment (which uses new EnvMode.WATCH) for watching agent performance after training (instead of using test environments, which the user may want to configure differently)
This commit is contained in:
parent
8188a904af
commit
eaab7b0a4b
@ -7,9 +7,15 @@ from collections import deque
|
||||
import cv2
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
from gymnasium import Env
|
||||
|
||||
from tianshou.env import ShmemVectorEnv
|
||||
from tianshou.highlevel.env import DiscreteEnvironments, EnvFactory
|
||||
from tianshou.highlevel.env import (
|
||||
EnvFactoryGymnasium,
|
||||
EnvMode,
|
||||
EnvPoolFactory,
|
||||
VectorEnvType,
|
||||
)
|
||||
from tianshou.highlevel.trainer import TrainerStopCallback, TrainingContext
|
||||
|
||||
try:
|
||||
@ -282,7 +288,7 @@ class FrameStack(gym.Wrapper):
|
||||
|
||||
|
||||
def wrap_deepmind(
|
||||
env_id,
|
||||
env: Env,
|
||||
episode_life=True,
|
||||
clip_rewards=True,
|
||||
frame_stack=4,
|
||||
@ -293,7 +299,7 @@ def wrap_deepmind(
|
||||
|
||||
The observation is channel-first: (c, h, w) instead of (h, w, c).
|
||||
|
||||
:param str env_id: the atari environment id.
|
||||
:param env: the Atari environment to wrap.
|
||||
:param bool episode_life: wrap the episode life wrapper.
|
||||
:param bool clip_rewards: wrap the reward clipping wrapper.
|
||||
:param int frame_stack: wrap the frame stacking wrapper.
|
||||
@ -301,8 +307,6 @@ def wrap_deepmind(
|
||||
:param bool warp_frame: wrap the grayscale + resize observation wrapper.
|
||||
:return: the wrapped atari environment.
|
||||
"""
|
||||
assert "NoFrameskip" in env_id
|
||||
env = gym.make(env_id)
|
||||
env = NoopResetEnv(env, noop_max=30)
|
||||
env = MaxAndSkipEnv(env, skip=4)
|
||||
if episode_life:
|
||||
@ -351,19 +355,30 @@ def make_atari_env(task, seed, training_num, test_num, **kwargs):
|
||||
stack_num=kwargs.get("frame_stack", 4),
|
||||
)
|
||||
else:
|
||||
assert "NoFrameskip" in task
|
||||
warnings.warn(
|
||||
"Recommend using envpool (pip install envpool) to run Atari games more efficiently.",
|
||||
)
|
||||
env = wrap_deepmind(task, **kwargs)
|
||||
train_envs = ShmemVectorEnv(
|
||||
[
|
||||
lambda: wrap_deepmind(task, episode_life=True, clip_rewards=True, **kwargs)
|
||||
lambda: wrap_deepmind(
|
||||
gym.make(task),
|
||||
episode_life=True,
|
||||
clip_rewards=True,
|
||||
**kwargs,
|
||||
)
|
||||
for _ in range(training_num)
|
||||
],
|
||||
)
|
||||
test_envs = ShmemVectorEnv(
|
||||
[
|
||||
lambda: wrap_deepmind(task, episode_life=False, clip_rewards=False, **kwargs)
|
||||
lambda: wrap_deepmind(
|
||||
gym.make(task),
|
||||
episode_life=False,
|
||||
clip_rewards=False,
|
||||
**kwargs,
|
||||
)
|
||||
for _ in range(test_num)
|
||||
],
|
||||
)
|
||||
@ -373,23 +388,49 @@ def make_atari_env(task, seed, training_num, test_num, **kwargs):
|
||||
return env, train_envs, test_envs
|
||||
|
||||
|
||||
class AtariEnvFactory(EnvFactory):
|
||||
def __init__(self, task: str, seed: int, frame_stack: int, scale: int = 0):
|
||||
self.task = task
|
||||
self.seed = seed
|
||||
class AtariEnvFactory(EnvFactoryGymnasium):
|
||||
def __init__(self, task: str, seed: int, frame_stack: int, scale: bool = False):
|
||||
assert "NoFrameskip" in task
|
||||
self.frame_stack = frame_stack
|
||||
self.scale = scale
|
||||
|
||||
def create_envs(self, num_training_envs: int, num_test_envs: int) -> DiscreteEnvironments:
|
||||
env, train_envs, test_envs = make_atari_env(
|
||||
task=self.task,
|
||||
seed=self.seed,
|
||||
training_num=num_training_envs,
|
||||
test_num=num_test_envs,
|
||||
scale=self.scale,
|
||||
frame_stack=self.frame_stack,
|
||||
super().__init__(
|
||||
task=task,
|
||||
seed=seed,
|
||||
venv_type=VectorEnvType.SUBPROC_SHARED_MEM,
|
||||
envpool_factory=self.EnvPoolFactory(self),
|
||||
)
|
||||
return DiscreteEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)
|
||||
|
||||
def create_env(self, mode: EnvMode) -> Env:
|
||||
env = super().create_env(mode)
|
||||
is_train = mode == EnvMode.TRAIN
|
||||
return wrap_deepmind(
|
||||
env,
|
||||
episode_life=is_train,
|
||||
clip_rewards=is_train,
|
||||
frame_stack=self.frame_stack,
|
||||
scale=self.scale,
|
||||
)
|
||||
|
||||
class EnvPoolFactory(EnvPoolFactory):
|
||||
def __init__(self, parent: "AtariEnvFactory"):
|
||||
self.parent = parent
|
||||
if self.parent.scale:
|
||||
warnings.warn(
|
||||
"EnvPool does not include ScaledFloatFrame wrapper, "
|
||||
"please compensate by scaling inside your network's forward function (e.g. `x = x / 255.0` for Atari)",
|
||||
)
|
||||
|
||||
def _transform_task(self, task: str) -> str:
|
||||
task = super()._transform_task(task)
|
||||
return task.replace("NoFrameskip-v4", "-v5")
|
||||
|
||||
def _transform_kwargs(self, kwargs: dict, mode: EnvMode) -> dict:
|
||||
kwargs = super()._transform_kwargs(kwargs, mode)
|
||||
is_train = mode == EnvMode.TRAIN
|
||||
kwargs["reward_clip"] = is_train
|
||||
kwargs["episodic_life"] = is_train
|
||||
kwargs["stack_num"] = self.parent.frame_stack
|
||||
return kwargs
|
||||
|
||||
|
||||
class AtariStopCallback(TrainerStopCallback):
|
||||
|
@ -1,11 +1,13 @@
|
||||
import logging
|
||||
import pickle
|
||||
import warnings
|
||||
|
||||
import gymnasium as gym
|
||||
|
||||
from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
|
||||
from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
|
||||
from tianshou.env import VectorEnvNormObs
|
||||
from tianshou.highlevel.env import (
|
||||
ContinuousEnvironments,
|
||||
EnvFactoryGymnasium,
|
||||
EnvPoolFactory,
|
||||
VectorEnvType,
|
||||
)
|
||||
from tianshou.highlevel.persistence import Persistence, PersistEvent, RestoreEvent
|
||||
from tianshou.highlevel.world import World
|
||||
|
||||
@ -24,25 +26,11 @@ def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: in
|
||||
|
||||
:return: a tuple of (single env, training envs, test envs).
|
||||
"""
|
||||
if envpool is not None:
|
||||
train_envs = env = envpool.make_gymnasium(task, num_envs=num_train_envs, seed=seed)
|
||||
test_envs = envpool.make_gymnasium(task, num_envs=num_test_envs, seed=seed)
|
||||
else:
|
||||
warnings.warn(
|
||||
"Recommend using envpool (pip install envpool) "
|
||||
"to run Mujoco environments more efficiently.",
|
||||
)
|
||||
env = gym.make(task)
|
||||
train_envs = ShmemVectorEnv([lambda: gym.make(task) for _ in range(num_train_envs)])
|
||||
test_envs = ShmemVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)])
|
||||
train_envs.seed(seed)
|
||||
test_envs.seed(seed)
|
||||
if obs_norm:
|
||||
# obs norm wrapper
|
||||
train_envs = VectorEnvNormObs(train_envs)
|
||||
test_envs = VectorEnvNormObs(test_envs, update_obs_rms=False)
|
||||
test_envs.set_obs_rms(train_envs.get_obs_rms())
|
||||
return env, train_envs, test_envs
|
||||
envs = MujocoEnvFactory(task, seed, obs_norm=obs_norm).create_envs(
|
||||
num_train_envs,
|
||||
num_test_envs,
|
||||
)
|
||||
return envs.env, envs.train_envs, envs.test_envs
|
||||
|
||||
|
||||
class MujocoEnvObsRmsPersistence(Persistence):
|
||||
@ -68,21 +56,25 @@ class MujocoEnvObsRmsPersistence(Persistence):
|
||||
world.envs.test_envs.set_obs_rms(obs_rms)
|
||||
|
||||
|
||||
class MujocoEnvFactory(EnvFactory):
|
||||
class MujocoEnvFactory(EnvFactoryGymnasium):
|
||||
def __init__(self, task: str, seed: int, obs_norm=True):
|
||||
self.task = task
|
||||
self.seed = seed
|
||||
super().__init__(
|
||||
task=task,
|
||||
seed=seed,
|
||||
venv_type=VectorEnvType.SUBPROC_SHARED_MEM,
|
||||
envpool_factory=EnvPoolFactory(),
|
||||
)
|
||||
self.obs_norm = obs_norm
|
||||
|
||||
def create_envs(self, num_training_envs: int, num_test_envs: int) -> ContinuousEnvironments:
|
||||
env, train_envs, test_envs = make_mujoco_env(
|
||||
task=self.task,
|
||||
seed=self.seed,
|
||||
num_train_envs=num_training_envs,
|
||||
num_test_envs=num_test_envs,
|
||||
obs_norm=self.obs_norm,
|
||||
)
|
||||
envs = ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)
|
||||
envs = super().create_envs(num_training_envs, num_test_envs)
|
||||
assert isinstance(envs, ContinuousEnvironments)
|
||||
|
||||
# obs norm wrapper
|
||||
if self.obs_norm:
|
||||
envs.train_envs = VectorEnvNormObs(envs.train_envs)
|
||||
envs.test_envs = VectorEnvNormObs(envs.test_envs, update_obs_rms=False)
|
||||
envs.test_envs.set_obs_rms(envs.train_envs.get_obs_rms())
|
||||
envs.set_persistence(MujocoEnvObsRmsPersistence())
|
||||
|
||||
return envs
|
||||
|
@ -1,27 +1,14 @@
|
||||
import gymnasium as gym
|
||||
|
||||
from tianshou.env import DummyVectorEnv
|
||||
from tianshou.highlevel.env import (
|
||||
ContinuousEnvironments,
|
||||
DiscreteEnvironments,
|
||||
EnvFactory,
|
||||
Environments,
|
||||
EnvFactoryGymnasium,
|
||||
VectorEnvType,
|
||||
)
|
||||
|
||||
|
||||
class DiscreteTestEnvFactory(EnvFactory):
|
||||
def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
|
||||
task = "CartPole-v0"
|
||||
env = gym.make(task)
|
||||
train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)])
|
||||
test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)])
|
||||
return DiscreteEnvironments(env, train_envs, test_envs)
|
||||
class DiscreteTestEnvFactory(EnvFactoryGymnasium):
|
||||
def __init__(self):
|
||||
super().__init__(task="CartPole-v0", seed=42, venv_type=VectorEnvType.DUMMY)
|
||||
|
||||
|
||||
class ContinuousTestEnvFactory(EnvFactory):
|
||||
def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
|
||||
task = "Pendulum-v1"
|
||||
env = gym.make(task)
|
||||
train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)])
|
||||
test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)])
|
||||
return ContinuousEnvironments(env, train_envs, test_envs)
|
||||
class ContinuousTestEnvFactory(EnvFactoryGymnasium):
|
||||
def __init__(self):
|
||||
super().__init__(task="Pendulum-v1", seed=42, venv_type=VectorEnvType.DUMMY)
|
||||
|
@ -1,9 +1,12 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Sequence
|
||||
from enum import Enum
|
||||
from typing import Any, TypeAlias, cast
|
||||
|
||||
import gymnasium as gym
|
||||
import gymnasium.spaces
|
||||
from gymnasium import Env
|
||||
|
||||
from tianshou.env import (
|
||||
BaseVectorEnv,
|
||||
@ -18,6 +21,8 @@ from tianshou.utils.string import ToStringMixin
|
||||
|
||||
TObservationShape: TypeAlias = int | Sequence[int]
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EnvType(Enum):
|
||||
"""Enumeration of environment types."""
|
||||
@ -39,6 +44,23 @@ class EnvType(Enum):
|
||||
if not self.is_discrete():
|
||||
raise AssertionError(f"{requiring_entity} requires discrete environments")
|
||||
|
||||
@staticmethod
|
||||
def from_env(env: Env) -> "EnvType":
|
||||
if isinstance(env.action_space, gymnasium.spaces.Discrete):
|
||||
return EnvType.DISCRETE
|
||||
elif isinstance(env.action_space, gymnasium.spaces.Box):
|
||||
return EnvType.CONTINUOUS
|
||||
else:
|
||||
raise Exception(f"Unsupported environment type with action space {env.action_space}")
|
||||
|
||||
|
||||
class EnvMode(Enum):
|
||||
"""Indicates the purpose for which an environment is created."""
|
||||
|
||||
TRAIN = "train"
|
||||
TEST = "test"
|
||||
WATCH = "watch"
|
||||
|
||||
|
||||
class VectorEnvType(Enum):
|
||||
DUMMY = "dummy"
|
||||
@ -65,7 +87,7 @@ class VectorEnvType(Enum):
|
||||
|
||||
|
||||
class Environments(ToStringMixin, ABC):
|
||||
"""Represents (vectorized) environments."""
|
||||
"""Represents (vectorized) environments for a learning process."""
|
||||
|
||||
def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
|
||||
self.env = env
|
||||
@ -75,12 +97,11 @@ class Environments(ToStringMixin, ABC):
|
||||
|
||||
@staticmethod
|
||||
def from_factory_and_type(
|
||||
factory_fn: Callable[[], gym.Env],
|
||||
factory_fn: Callable[[EnvMode], gym.Env],
|
||||
env_type: EnvType,
|
||||
venv_type: VectorEnvType,
|
||||
num_training_envs: int,
|
||||
num_test_envs: int,
|
||||
test_factory_fn: Callable[[], gym.Env] | None = None,
|
||||
) -> "Environments":
|
||||
"""Creates a suitable subtype instance from a factory function that creates a single instance and the type of environment (continuous/discrete).
|
||||
|
||||
@ -89,15 +110,11 @@ class Environments(ToStringMixin, ABC):
|
||||
:param venv_type: the vector environment type to use for parallelization
|
||||
:param num_training_envs: the number of training environments to create
|
||||
:param num_test_envs: the number of test environments to create
|
||||
:param test_factory_fn: the factory to use for the creation of test environment instances;
|
||||
if None, use `factory_fn` for all environments (train and test)
|
||||
:return: the instance
|
||||
"""
|
||||
if test_factory_fn is None:
|
||||
test_factory_fn = factory_fn
|
||||
train_envs = venv_type.create_venv([factory_fn] * num_training_envs)
|
||||
test_envs = venv_type.create_venv([test_factory_fn] * num_test_envs)
|
||||
env = factory_fn()
|
||||
train_envs = venv_type.create_venv([lambda: factory_fn(EnvMode.TRAIN)] * num_training_envs)
|
||||
test_envs = venv_type.create_venv([lambda: factory_fn(EnvMode.TEST)] * num_test_envs)
|
||||
env = factory_fn(EnvMode.TRAIN)
|
||||
match env_type:
|
||||
case EnvType.CONTINUOUS:
|
||||
return ContinuousEnvironments(env, train_envs, test_envs)
|
||||
@ -153,11 +170,10 @@ class ContinuousEnvironments(Environments):
|
||||
|
||||
@staticmethod
|
||||
def from_factory(
|
||||
factory_fn: Callable[[], gym.Env],
|
||||
factory_fn: Callable[[EnvMode], gym.Env],
|
||||
venv_type: VectorEnvType,
|
||||
num_training_envs: int,
|
||||
num_test_envs: int,
|
||||
test_factory_fn: Callable[[], gym.Env] | None = None,
|
||||
) -> "ContinuousEnvironments":
|
||||
"""Creates an instance from a factory function that creates a single instance.
|
||||
|
||||
@ -165,8 +181,6 @@ class ContinuousEnvironments(Environments):
|
||||
:param venv_type: the vector environment type to use for parallelization
|
||||
:param num_training_envs: the number of training environments to create
|
||||
:param num_test_envs: the number of test environments to create
|
||||
:param test_factory_fn: the factory to use for the creation of test environment instances;
|
||||
if None, use `factory_fn` for all environments (train and test)
|
||||
:return: the instance
|
||||
"""
|
||||
return cast(
|
||||
@ -177,7 +191,6 @@ class ContinuousEnvironments(Environments):
|
||||
venv_type,
|
||||
num_training_envs,
|
||||
num_test_envs,
|
||||
test_factory_fn=test_factory_fn,
|
||||
),
|
||||
)
|
||||
|
||||
@ -222,11 +235,10 @@ class DiscreteEnvironments(Environments):
|
||||
|
||||
@staticmethod
|
||||
def from_factory(
|
||||
factory_fn: Callable[[], gym.Env],
|
||||
factory_fn: Callable[[EnvMode], gym.Env],
|
||||
venv_type: VectorEnvType,
|
||||
num_training_envs: int,
|
||||
num_test_envs: int,
|
||||
test_factory_fn: Callable[[], gym.Env] | None = None,
|
||||
) -> "DiscreteEnvironments":
|
||||
"""Creates an instance from a factory function that creates a single instance.
|
||||
|
||||
@ -234,8 +246,6 @@ class DiscreteEnvironments(Environments):
|
||||
:param venv_type: the vector environment type to use for parallelization
|
||||
:param num_training_envs: the number of training environments to create
|
||||
:param num_test_envs: the number of test environments to create
|
||||
:param test_factory_fn: the factory to use for the creation of test environment instances;
|
||||
if None, use `factory_fn` for all environments (train and test)
|
||||
:return: the instance
|
||||
"""
|
||||
return cast(
|
||||
@ -246,7 +256,6 @@ class DiscreteEnvironments(Environments):
|
||||
venv_type,
|
||||
num_training_envs,
|
||||
num_test_envs,
|
||||
test_factory_fn=test_factory_fn,
|
||||
),
|
||||
)
|
||||
|
||||
@ -260,7 +269,156 @@ class DiscreteEnvironments(Environments):
|
||||
return EnvType.DISCRETE
|
||||
|
||||
|
||||
class EnvPoolFactory:
|
||||
"""A factory for the creation of envpool-based vectorized environments."""
|
||||
|
||||
def _transform_task(self, task: str) -> str:
|
||||
return task
|
||||
|
||||
def _transform_kwargs(self, kwargs: dict, mode: EnvMode) -> dict:
|
||||
"""Transforms gymnasium keyword arguments to be envpool-compatible.
|
||||
|
||||
:param kwargs: keyword arguments that would normally be passed to `gymnasium.make`.
|
||||
:param mode: the environment mode
|
||||
:return: the transformed keyword arguments
|
||||
"""
|
||||
kwargs = dict(kwargs)
|
||||
if "render_mode" in kwargs:
|
||||
del kwargs["render_mode"]
|
||||
return kwargs
|
||||
|
||||
def create_venv(
|
||||
self,
|
||||
task: str,
|
||||
num_envs: int,
|
||||
mode: EnvMode,
|
||||
seed: int,
|
||||
kwargs: dict,
|
||||
) -> BaseVectorEnv | None:
|
||||
try:
|
||||
import envpool
|
||||
|
||||
envpool_task = self._transform_task(task)
|
||||
envpool_kwargs = self._transform_kwargs(kwargs, mode)
|
||||
return envpool.make_gymnasium(
|
||||
envpool_task,
|
||||
num_envs=num_envs,
|
||||
seed=seed,
|
||||
**envpool_kwargs,
|
||||
)
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
class EnvFactory(ToStringMixin, ABC):
|
||||
def __init__(self, venv_type: VectorEnvType):
|
||||
""":param venv_type: the type of vectorized environment to use"""
|
||||
self.venv_type = venv_type
|
||||
|
||||
@abstractmethod
|
||||
def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
|
||||
def create_env(self, mode: EnvMode) -> Env:
|
||||
pass
|
||||
|
||||
def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv:
|
||||
"""Create vectorized environments.
|
||||
|
||||
:param num_envs: the number of environments
|
||||
:param mode: the mode for which to create
|
||||
:return: the vectorized environments
|
||||
"""
|
||||
return self.venv_type.create_venv([lambda: self.create_env(mode)] * num_envs)
|
||||
|
||||
def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
|
||||
"""Create environments for learning.
|
||||
|
||||
:param num_training_envs: the number of training environments
|
||||
:param num_test_envs: the number of test environments
|
||||
:return: the environments
|
||||
"""
|
||||
env = self.create_env(EnvMode.TRAIN)
|
||||
train_envs = self.create_venv(num_training_envs, EnvMode.TRAIN)
|
||||
test_envs = self.create_venv(num_test_envs, EnvMode.TEST)
|
||||
match EnvType.from_env(env):
|
||||
case EnvType.DISCRETE:
|
||||
return DiscreteEnvironments(env, train_envs, test_envs)
|
||||
case EnvType.CONTINUOUS:
|
||||
return ContinuousEnvironments(env, train_envs, test_envs)
|
||||
case _:
|
||||
raise ValueError
|
||||
|
||||
|
||||
class EnvFactoryGymnasium(EnvFactory):
|
||||
"""Factory for environments that can be created via `gymnasium.make` (or via `envpool.make_gymnasium`)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
task: str,
|
||||
seed: int,
|
||||
venv_type: VectorEnvType,
|
||||
envpool_factory: EnvPoolFactory | None = None,
|
||||
render_mode_train: str | None = None,
|
||||
render_mode_test: str | None = None,
|
||||
render_mode_watch: str = "human",
|
||||
**kwargs: Any,
|
||||
):
|
||||
""":param task: the gymnasium task/environment identifier
|
||||
:param seed: the random seed
|
||||
:param venv_type: the type of vectorized environment to use. If `envpool_factory` is specified, this is but a fallback.
|
||||
:param envpool_factory: the factory to use for envpool-based vectorized environment creation if `envpool` is installed.
|
||||
If it is not installed, `venv_type` applies as a fallback.
|
||||
:param render_mode_train: the render mode to use for training environments
|
||||
:param render_mode_test: the render mode to use for test environments
|
||||
:param render_mode_watch: the render mode to use for environments that are used to watch agent performance
|
||||
:param kwargs: additional keyword arguments to pass on to `gymnasium.make`.
|
||||
If envpool is used, the gymnasium parameters will be appropriately translated for use with
|
||||
`envpool.make_gymnasium`.
|
||||
"""
|
||||
super().__init__(venv_type)
|
||||
self.task = task
|
||||
self.envpool_factory = envpool_factory
|
||||
self.seed = seed
|
||||
self.render_modes = {
|
||||
EnvMode.TRAIN: render_mode_train,
|
||||
EnvMode.TEST: render_mode_test,
|
||||
EnvMode.WATCH: render_mode_watch,
|
||||
}
|
||||
self.kwargs = kwargs
|
||||
|
||||
def _create_kwargs(self, mode: EnvMode) -> dict:
|
||||
"""Adapts the keyword arguments for the given mode.
|
||||
|
||||
:param mode: the mode
|
||||
:return: adapted keyword arguments
|
||||
"""
|
||||
kwargs = dict(self.kwargs)
|
||||
kwargs["render_mode"] = self.render_modes.get(mode)
|
||||
return kwargs
|
||||
|
||||
def create_env(self, mode: EnvMode) -> Env:
|
||||
"""Creates a single environment for the given mode.
|
||||
|
||||
:param mode: the mode
|
||||
:return: an environment
|
||||
"""
|
||||
kwargs = self._create_kwargs(mode)
|
||||
return gymnasium.make(self.task, **kwargs)
|
||||
|
||||
def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv:
|
||||
if self.envpool_factory is not None:
|
||||
venv = self.envpool_factory.create_venv(
|
||||
self.task,
|
||||
num_envs,
|
||||
mode,
|
||||
self.seed,
|
||||
self._create_kwargs(mode),
|
||||
)
|
||||
if venv is not None:
|
||||
return venv
|
||||
log.debug(
|
||||
f"EnvPool-based creation could not be applied, falling back to default based on {self.venv_type}",
|
||||
)
|
||||
|
||||
venv = super().create_venv(num_envs, mode)
|
||||
venv.seed(self.seed)
|
||||
return venv
|
||||
|
@ -26,7 +26,7 @@ from tianshou.highlevel.agent import (
|
||||
TRPOAgentFactory,
|
||||
)
|
||||
from tianshou.highlevel.config import SamplingConfig
|
||||
from tianshou.highlevel.env import EnvFactory
|
||||
from tianshou.highlevel.env import EnvFactory, EnvMode
|
||||
from tianshou.highlevel.logger import LoggerFactory, LoggerFactoryDefault, TLogger
|
||||
from tianshou.highlevel.module.actor import (
|
||||
ActorFactory,
|
||||
@ -293,7 +293,7 @@ class Experiment(ToStringMixin):
|
||||
self._watch_agent(
|
||||
self.config.watch_num_episodes,
|
||||
policy,
|
||||
test_collector,
|
||||
self.env_factory,
|
||||
self.config.watch_render,
|
||||
)
|
||||
|
||||
@ -303,15 +303,18 @@ class Experiment(ToStringMixin):
|
||||
def _watch_agent(
|
||||
num_episodes: int,
|
||||
policy: BasePolicy,
|
||||
test_collector: Collector,
|
||||
env_factory: EnvFactory,
|
||||
render: float,
|
||||
) -> None:
|
||||
policy.eval()
|
||||
test_collector.reset()
|
||||
result = test_collector.collect(n_episode=num_episodes, render=render)
|
||||
env = env_factory.create_env(EnvMode.WATCH)
|
||||
collector = Collector(policy, env)
|
||||
result = collector.collect(n_episode=num_episodes, render=render)
|
||||
assert result.returns_stat is not None # for mypy
|
||||
assert result.lens_stat is not None # for mypy
|
||||
print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}")
|
||||
log.info(
|
||||
f"Watched episodes: mean reward={result.returns_stat.mean}, mean episode length={result.lens_stat.mean}",
|
||||
)
|
||||
|
||||
|
||||
class ExperimentBuilder:
|
||||
|
Loading…
x
Reference in New Issue
Block a user