Improve environment factory abstractions in high-level API:
* EnvFactory now uses the creation of a single environment as the basic functionality on which the higher-level functions build
* Introduce enum EnvMode to indicate the purpose for which an env is created, allowing the factory creation process to change its behaviour accordingly
* Add EnvFactoryGymnasium to provide direct support for envs that can be created via gymnasium.make
  - EnvPool is supported via an injectable EnvPoolFactory
  - Existing EnvFactory implementations are now derived from EnvFactoryGymnasium
* Use a separate environment (which uses the new EnvMode.WATCH) for watching agent performance after training (instead of using test environments, which the user may want to configure differently)
This commit is contained in:
parent
8188a904af
commit
eaab7b0a4b
@ -7,9 +7,15 @@ from collections import deque
|
|||||||
import cv2
|
import cv2
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from gymnasium import Env
|
||||||
|
|
||||||
from tianshou.env import ShmemVectorEnv
|
from tianshou.env import ShmemVectorEnv
|
||||||
from tianshou.highlevel.env import DiscreteEnvironments, EnvFactory
|
from tianshou.highlevel.env import (
|
||||||
|
EnvFactoryGymnasium,
|
||||||
|
EnvMode,
|
||||||
|
EnvPoolFactory,
|
||||||
|
VectorEnvType,
|
||||||
|
)
|
||||||
from tianshou.highlevel.trainer import TrainerStopCallback, TrainingContext
|
from tianshou.highlevel.trainer import TrainerStopCallback, TrainingContext
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -282,7 +288,7 @@ class FrameStack(gym.Wrapper):
|
|||||||
|
|
||||||
|
|
||||||
def wrap_deepmind(
|
def wrap_deepmind(
|
||||||
env_id,
|
env: Env,
|
||||||
episode_life=True,
|
episode_life=True,
|
||||||
clip_rewards=True,
|
clip_rewards=True,
|
||||||
frame_stack=4,
|
frame_stack=4,
|
||||||
@ -293,7 +299,7 @@ def wrap_deepmind(
|
|||||||
|
|
||||||
The observation is channel-first: (c, h, w) instead of (h, w, c).
|
The observation is channel-first: (c, h, w) instead of (h, w, c).
|
||||||
|
|
||||||
:param str env_id: the atari environment id.
|
:param env: the Atari environment to wrap.
|
||||||
:param bool episode_life: wrap the episode life wrapper.
|
:param bool episode_life: wrap the episode life wrapper.
|
||||||
:param bool clip_rewards: wrap the reward clipping wrapper.
|
:param bool clip_rewards: wrap the reward clipping wrapper.
|
||||||
:param int frame_stack: wrap the frame stacking wrapper.
|
:param int frame_stack: wrap the frame stacking wrapper.
|
||||||
@ -301,8 +307,6 @@ def wrap_deepmind(
|
|||||||
:param bool warp_frame: wrap the grayscale + resize observation wrapper.
|
:param bool warp_frame: wrap the grayscale + resize observation wrapper.
|
||||||
:return: the wrapped atari environment.
|
:return: the wrapped atari environment.
|
||||||
"""
|
"""
|
||||||
assert "NoFrameskip" in env_id
|
|
||||||
env = gym.make(env_id)
|
|
||||||
env = NoopResetEnv(env, noop_max=30)
|
env = NoopResetEnv(env, noop_max=30)
|
||||||
env = MaxAndSkipEnv(env, skip=4)
|
env = MaxAndSkipEnv(env, skip=4)
|
||||||
if episode_life:
|
if episode_life:
|
||||||
@ -351,19 +355,30 @@ def make_atari_env(task, seed, training_num, test_num, **kwargs):
|
|||||||
stack_num=kwargs.get("frame_stack", 4),
|
stack_num=kwargs.get("frame_stack", 4),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
assert "NoFrameskip" in task
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"Recommend using envpool (pip install envpool) to run Atari games more efficiently.",
|
"Recommend using envpool (pip install envpool) to run Atari games more efficiently.",
|
||||||
)
|
)
|
||||||
env = wrap_deepmind(task, **kwargs)
|
env = wrap_deepmind(task, **kwargs)
|
||||||
train_envs = ShmemVectorEnv(
|
train_envs = ShmemVectorEnv(
|
||||||
[
|
[
|
||||||
lambda: wrap_deepmind(task, episode_life=True, clip_rewards=True, **kwargs)
|
lambda: wrap_deepmind(
|
||||||
|
gym.make(task),
|
||||||
|
episode_life=True,
|
||||||
|
clip_rewards=True,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
for _ in range(training_num)
|
for _ in range(training_num)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
test_envs = ShmemVectorEnv(
|
test_envs = ShmemVectorEnv(
|
||||||
[
|
[
|
||||||
lambda: wrap_deepmind(task, episode_life=False, clip_rewards=False, **kwargs)
|
lambda: wrap_deepmind(
|
||||||
|
gym.make(task),
|
||||||
|
episode_life=False,
|
||||||
|
clip_rewards=False,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
for _ in range(test_num)
|
for _ in range(test_num)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -373,23 +388,49 @@ def make_atari_env(task, seed, training_num, test_num, **kwargs):
|
|||||||
return env, train_envs, test_envs
|
return env, train_envs, test_envs
|
||||||
|
|
||||||
|
|
||||||
class AtariEnvFactory(EnvFactory):
|
class AtariEnvFactory(EnvFactoryGymnasium):
|
||||||
def __init__(self, task: str, seed: int, frame_stack: int, scale: int = 0):
|
def __init__(self, task: str, seed: int, frame_stack: int, scale: bool = False):
|
||||||
self.task = task
|
assert "NoFrameskip" in task
|
||||||
self.seed = seed
|
|
||||||
self.frame_stack = frame_stack
|
self.frame_stack = frame_stack
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
|
super().__init__(
|
||||||
def create_envs(self, num_training_envs: int, num_test_envs: int) -> DiscreteEnvironments:
|
task=task,
|
||||||
env, train_envs, test_envs = make_atari_env(
|
seed=seed,
|
||||||
task=self.task,
|
venv_type=VectorEnvType.SUBPROC_SHARED_MEM,
|
||||||
seed=self.seed,
|
envpool_factory=self.EnvPoolFactory(self),
|
||||||
training_num=num_training_envs,
|
|
||||||
test_num=num_test_envs,
|
|
||||||
scale=self.scale,
|
|
||||||
frame_stack=self.frame_stack,
|
|
||||||
)
|
)
|
||||||
return DiscreteEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)
|
|
||||||
|
def create_env(self, mode: EnvMode) -> Env:
|
||||||
|
env = super().create_env(mode)
|
||||||
|
is_train = mode == EnvMode.TRAIN
|
||||||
|
return wrap_deepmind(
|
||||||
|
env,
|
||||||
|
episode_life=is_train,
|
||||||
|
clip_rewards=is_train,
|
||||||
|
frame_stack=self.frame_stack,
|
||||||
|
scale=self.scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
class EnvPoolFactory(EnvPoolFactory):
|
||||||
|
def __init__(self, parent: "AtariEnvFactory"):
|
||||||
|
self.parent = parent
|
||||||
|
if self.parent.scale:
|
||||||
|
warnings.warn(
|
||||||
|
"EnvPool does not include ScaledFloatFrame wrapper, "
|
||||||
|
"please compensate by scaling inside your network's forward function (e.g. `x = x / 255.0` for Atari)",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _transform_task(self, task: str) -> str:
|
||||||
|
task = super()._transform_task(task)
|
||||||
|
return task.replace("NoFrameskip-v4", "-v5")
|
||||||
|
|
||||||
|
def _transform_kwargs(self, kwargs: dict, mode: EnvMode) -> dict:
|
||||||
|
kwargs = super()._transform_kwargs(kwargs, mode)
|
||||||
|
is_train = mode == EnvMode.TRAIN
|
||||||
|
kwargs["reward_clip"] = is_train
|
||||||
|
kwargs["episodic_life"] = is_train
|
||||||
|
kwargs["stack_num"] = self.parent.frame_stack
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
|
||||||
class AtariStopCallback(TrainerStopCallback):
|
class AtariStopCallback(TrainerStopCallback):
|
||||||
|
@ -1,11 +1,13 @@
|
|||||||
import logging
|
import logging
|
||||||
import pickle
|
import pickle
|
||||||
import warnings
|
|
||||||
|
|
||||||
import gymnasium as gym
|
from tianshou.env import VectorEnvNormObs
|
||||||
|
from tianshou.highlevel.env import (
|
||||||
from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
|
ContinuousEnvironments,
|
||||||
from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
|
EnvFactoryGymnasium,
|
||||||
|
EnvPoolFactory,
|
||||||
|
VectorEnvType,
|
||||||
|
)
|
||||||
from tianshou.highlevel.persistence import Persistence, PersistEvent, RestoreEvent
|
from tianshou.highlevel.persistence import Persistence, PersistEvent, RestoreEvent
|
||||||
from tianshou.highlevel.world import World
|
from tianshou.highlevel.world import World
|
||||||
|
|
||||||
@ -24,25 +26,11 @@ def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: in
|
|||||||
|
|
||||||
:return: a tuple of (single env, training envs, test envs).
|
:return: a tuple of (single env, training envs, test envs).
|
||||||
"""
|
"""
|
||||||
if envpool is not None:
|
envs = MujocoEnvFactory(task, seed, obs_norm=obs_norm).create_envs(
|
||||||
train_envs = env = envpool.make_gymnasium(task, num_envs=num_train_envs, seed=seed)
|
num_train_envs,
|
||||||
test_envs = envpool.make_gymnasium(task, num_envs=num_test_envs, seed=seed)
|
num_test_envs,
|
||||||
else:
|
|
||||||
warnings.warn(
|
|
||||||
"Recommend using envpool (pip install envpool) "
|
|
||||||
"to run Mujoco environments more efficiently.",
|
|
||||||
)
|
)
|
||||||
env = gym.make(task)
|
return envs.env, envs.train_envs, envs.test_envs
|
||||||
train_envs = ShmemVectorEnv([lambda: gym.make(task) for _ in range(num_train_envs)])
|
|
||||||
test_envs = ShmemVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)])
|
|
||||||
train_envs.seed(seed)
|
|
||||||
test_envs.seed(seed)
|
|
||||||
if obs_norm:
|
|
||||||
# obs norm wrapper
|
|
||||||
train_envs = VectorEnvNormObs(train_envs)
|
|
||||||
test_envs = VectorEnvNormObs(test_envs, update_obs_rms=False)
|
|
||||||
test_envs.set_obs_rms(train_envs.get_obs_rms())
|
|
||||||
return env, train_envs, test_envs
|
|
||||||
|
|
||||||
|
|
||||||
class MujocoEnvObsRmsPersistence(Persistence):
|
class MujocoEnvObsRmsPersistence(Persistence):
|
||||||
@ -68,21 +56,25 @@ class MujocoEnvObsRmsPersistence(Persistence):
|
|||||||
world.envs.test_envs.set_obs_rms(obs_rms)
|
world.envs.test_envs.set_obs_rms(obs_rms)
|
||||||
|
|
||||||
|
|
||||||
class MujocoEnvFactory(EnvFactory):
|
class MujocoEnvFactory(EnvFactoryGymnasium):
|
||||||
def __init__(self, task: str, seed: int, obs_norm=True):
|
def __init__(self, task: str, seed: int, obs_norm=True):
|
||||||
self.task = task
|
super().__init__(
|
||||||
self.seed = seed
|
task=task,
|
||||||
|
seed=seed,
|
||||||
|
venv_type=VectorEnvType.SUBPROC_SHARED_MEM,
|
||||||
|
envpool_factory=EnvPoolFactory(),
|
||||||
|
)
|
||||||
self.obs_norm = obs_norm
|
self.obs_norm = obs_norm
|
||||||
|
|
||||||
def create_envs(self, num_training_envs: int, num_test_envs: int) -> ContinuousEnvironments:
|
def create_envs(self, num_training_envs: int, num_test_envs: int) -> ContinuousEnvironments:
|
||||||
env, train_envs, test_envs = make_mujoco_env(
|
envs = super().create_envs(num_training_envs, num_test_envs)
|
||||||
task=self.task,
|
assert isinstance(envs, ContinuousEnvironments)
|
||||||
seed=self.seed,
|
|
||||||
num_train_envs=num_training_envs,
|
# obs norm wrapper
|
||||||
num_test_envs=num_test_envs,
|
|
||||||
obs_norm=self.obs_norm,
|
|
||||||
)
|
|
||||||
envs = ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)
|
|
||||||
if self.obs_norm:
|
if self.obs_norm:
|
||||||
|
envs.train_envs = VectorEnvNormObs(envs.train_envs)
|
||||||
|
envs.test_envs = VectorEnvNormObs(envs.test_envs, update_obs_rms=False)
|
||||||
|
envs.test_envs.set_obs_rms(envs.train_envs.get_obs_rms())
|
||||||
envs.set_persistence(MujocoEnvObsRmsPersistence())
|
envs.set_persistence(MujocoEnvObsRmsPersistence())
|
||||||
|
|
||||||
return envs
|
return envs
|
||||||
|
@ -1,27 +1,14 @@
|
|||||||
import gymnasium as gym
|
|
||||||
|
|
||||||
from tianshou.env import DummyVectorEnv
|
|
||||||
from tianshou.highlevel.env import (
|
from tianshou.highlevel.env import (
|
||||||
ContinuousEnvironments,
|
EnvFactoryGymnasium,
|
||||||
DiscreteEnvironments,
|
VectorEnvType,
|
||||||
EnvFactory,
|
|
||||||
Environments,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class DiscreteTestEnvFactory(EnvFactory):
|
class DiscreteTestEnvFactory(EnvFactoryGymnasium):
|
||||||
def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
|
def __init__(self):
|
||||||
task = "CartPole-v0"
|
super().__init__(task="CartPole-v0", seed=42, venv_type=VectorEnvType.DUMMY)
|
||||||
env = gym.make(task)
|
|
||||||
train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)])
|
|
||||||
test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)])
|
|
||||||
return DiscreteEnvironments(env, train_envs, test_envs)
|
|
||||||
|
|
||||||
|
|
||||||
class ContinuousTestEnvFactory(EnvFactory):
|
class ContinuousTestEnvFactory(EnvFactoryGymnasium):
|
||||||
def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
|
def __init__(self):
|
||||||
task = "Pendulum-v1"
|
super().__init__(task="Pendulum-v1", seed=42, venv_type=VectorEnvType.DUMMY)
|
||||||
env = gym.make(task)
|
|
||||||
train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)])
|
|
||||||
test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)])
|
|
||||||
return ContinuousEnvironments(env, train_envs, test_envs)
|
|
||||||
|
@ -1,9 +1,12 @@
|
|||||||
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, TypeAlias, cast
|
from typing import Any, TypeAlias, cast
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
import gymnasium.spaces
|
||||||
|
from gymnasium import Env
|
||||||
|
|
||||||
from tianshou.env import (
|
from tianshou.env import (
|
||||||
BaseVectorEnv,
|
BaseVectorEnv,
|
||||||
@ -18,6 +21,8 @@ from tianshou.utils.string import ToStringMixin
|
|||||||
|
|
||||||
TObservationShape: TypeAlias = int | Sequence[int]
|
TObservationShape: TypeAlias = int | Sequence[int]
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class EnvType(Enum):
|
class EnvType(Enum):
|
||||||
"""Enumeration of environment types."""
|
"""Enumeration of environment types."""
|
||||||
@ -39,6 +44,23 @@ class EnvType(Enum):
|
|||||||
if not self.is_discrete():
|
if not self.is_discrete():
|
||||||
raise AssertionError(f"{requiring_entity} requires discrete environments")
|
raise AssertionError(f"{requiring_entity} requires discrete environments")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_env(env: Env) -> "EnvType":
|
||||||
|
if isinstance(env.action_space, gymnasium.spaces.Discrete):
|
||||||
|
return EnvType.DISCRETE
|
||||||
|
elif isinstance(env.action_space, gymnasium.spaces.Box):
|
||||||
|
return EnvType.CONTINUOUS
|
||||||
|
else:
|
||||||
|
raise Exception(f"Unsupported environment type with action space {env.action_space}")
|
||||||
|
|
||||||
|
|
||||||
|
class EnvMode(Enum):
|
||||||
|
"""Indicates the purpose for which an environment is created."""
|
||||||
|
|
||||||
|
TRAIN = "train"
|
||||||
|
TEST = "test"
|
||||||
|
WATCH = "watch"
|
||||||
|
|
||||||
|
|
||||||
class VectorEnvType(Enum):
|
class VectorEnvType(Enum):
|
||||||
DUMMY = "dummy"
|
DUMMY = "dummy"
|
||||||
@ -65,7 +87,7 @@ class VectorEnvType(Enum):
|
|||||||
|
|
||||||
|
|
||||||
class Environments(ToStringMixin, ABC):
|
class Environments(ToStringMixin, ABC):
|
||||||
"""Represents (vectorized) environments."""
|
"""Represents (vectorized) environments for a learning process."""
|
||||||
|
|
||||||
def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
|
def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
|
||||||
self.env = env
|
self.env = env
|
||||||
@ -75,12 +97,11 @@ class Environments(ToStringMixin, ABC):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_factory_and_type(
|
def from_factory_and_type(
|
||||||
factory_fn: Callable[[], gym.Env],
|
factory_fn: Callable[[EnvMode], gym.Env],
|
||||||
env_type: EnvType,
|
env_type: EnvType,
|
||||||
venv_type: VectorEnvType,
|
venv_type: VectorEnvType,
|
||||||
num_training_envs: int,
|
num_training_envs: int,
|
||||||
num_test_envs: int,
|
num_test_envs: int,
|
||||||
test_factory_fn: Callable[[], gym.Env] | None = None,
|
|
||||||
) -> "Environments":
|
) -> "Environments":
|
||||||
"""Creates a suitable subtype instance from a factory function that creates a single instance and the type of environment (continuous/discrete).
|
"""Creates a suitable subtype instance from a factory function that creates a single instance and the type of environment (continuous/discrete).
|
||||||
|
|
||||||
@ -89,15 +110,11 @@ class Environments(ToStringMixin, ABC):
|
|||||||
:param venv_type: the vector environment type to use for parallelization
|
:param venv_type: the vector environment type to use for parallelization
|
||||||
:param num_training_envs: the number of training environments to create
|
:param num_training_envs: the number of training environments to create
|
||||||
:param num_test_envs: the number of test environments to create
|
:param num_test_envs: the number of test environments to create
|
||||||
:param test_factory_fn: the factory to use for the creation of test environment instances;
|
|
||||||
if None, use `factory_fn` for all environments (train and test)
|
|
||||||
:return: the instance
|
:return: the instance
|
||||||
"""
|
"""
|
||||||
if test_factory_fn is None:
|
train_envs = venv_type.create_venv([lambda: factory_fn(EnvMode.TRAIN)] * num_training_envs)
|
||||||
test_factory_fn = factory_fn
|
test_envs = venv_type.create_venv([lambda: factory_fn(EnvMode.TEST)] * num_test_envs)
|
||||||
train_envs = venv_type.create_venv([factory_fn] * num_training_envs)
|
env = factory_fn(EnvMode.TRAIN)
|
||||||
test_envs = venv_type.create_venv([test_factory_fn] * num_test_envs)
|
|
||||||
env = factory_fn()
|
|
||||||
match env_type:
|
match env_type:
|
||||||
case EnvType.CONTINUOUS:
|
case EnvType.CONTINUOUS:
|
||||||
return ContinuousEnvironments(env, train_envs, test_envs)
|
return ContinuousEnvironments(env, train_envs, test_envs)
|
||||||
@ -153,11 +170,10 @@ class ContinuousEnvironments(Environments):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_factory(
|
def from_factory(
|
||||||
factory_fn: Callable[[], gym.Env],
|
factory_fn: Callable[[EnvMode], gym.Env],
|
||||||
venv_type: VectorEnvType,
|
venv_type: VectorEnvType,
|
||||||
num_training_envs: int,
|
num_training_envs: int,
|
||||||
num_test_envs: int,
|
num_test_envs: int,
|
||||||
test_factory_fn: Callable[[], gym.Env] | None = None,
|
|
||||||
) -> "ContinuousEnvironments":
|
) -> "ContinuousEnvironments":
|
||||||
"""Creates an instance from a factory function that creates a single instance.
|
"""Creates an instance from a factory function that creates a single instance.
|
||||||
|
|
||||||
@ -165,8 +181,6 @@ class ContinuousEnvironments(Environments):
|
|||||||
:param venv_type: the vector environment type to use for parallelization
|
:param venv_type: the vector environment type to use for parallelization
|
||||||
:param num_training_envs: the number of training environments to create
|
:param num_training_envs: the number of training environments to create
|
||||||
:param num_test_envs: the number of test environments to create
|
:param num_test_envs: the number of test environments to create
|
||||||
:param test_factory_fn: the factory to use for the creation of test environment instances;
|
|
||||||
if None, use `factory_fn` for all environments (train and test)
|
|
||||||
:return: the instance
|
:return: the instance
|
||||||
"""
|
"""
|
||||||
return cast(
|
return cast(
|
||||||
@ -177,7 +191,6 @@ class ContinuousEnvironments(Environments):
|
|||||||
venv_type,
|
venv_type,
|
||||||
num_training_envs,
|
num_training_envs,
|
||||||
num_test_envs,
|
num_test_envs,
|
||||||
test_factory_fn=test_factory_fn,
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -222,11 +235,10 @@ class DiscreteEnvironments(Environments):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_factory(
|
def from_factory(
|
||||||
factory_fn: Callable[[], gym.Env],
|
factory_fn: Callable[[EnvMode], gym.Env],
|
||||||
venv_type: VectorEnvType,
|
venv_type: VectorEnvType,
|
||||||
num_training_envs: int,
|
num_training_envs: int,
|
||||||
num_test_envs: int,
|
num_test_envs: int,
|
||||||
test_factory_fn: Callable[[], gym.Env] | None = None,
|
|
||||||
) -> "DiscreteEnvironments":
|
) -> "DiscreteEnvironments":
|
||||||
"""Creates an instance from a factory function that creates a single instance.
|
"""Creates an instance from a factory function that creates a single instance.
|
||||||
|
|
||||||
@ -234,8 +246,6 @@ class DiscreteEnvironments(Environments):
|
|||||||
:param venv_type: the vector environment type to use for parallelization
|
:param venv_type: the vector environment type to use for parallelization
|
||||||
:param num_training_envs: the number of training environments to create
|
:param num_training_envs: the number of training environments to create
|
||||||
:param num_test_envs: the number of test environments to create
|
:param num_test_envs: the number of test environments to create
|
||||||
:param test_factory_fn: the factory to use for the creation of test environment instances;
|
|
||||||
if None, use `factory_fn` for all environments (train and test)
|
|
||||||
:return: the instance
|
:return: the instance
|
||||||
"""
|
"""
|
||||||
return cast(
|
return cast(
|
||||||
@ -246,7 +256,6 @@ class DiscreteEnvironments(Environments):
|
|||||||
venv_type,
|
venv_type,
|
||||||
num_training_envs,
|
num_training_envs,
|
||||||
num_test_envs,
|
num_test_envs,
|
||||||
test_factory_fn=test_factory_fn,
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -260,7 +269,156 @@ class DiscreteEnvironments(Environments):
|
|||||||
return EnvType.DISCRETE
|
return EnvType.DISCRETE
|
||||||
|
|
||||||
|
|
||||||
|
class EnvPoolFactory:
|
||||||
|
"""A factory for the creation of envpool-based vectorized environments."""
|
||||||
|
|
||||||
|
def _transform_task(self, task: str) -> str:
|
||||||
|
return task
|
||||||
|
|
||||||
|
def _transform_kwargs(self, kwargs: dict, mode: EnvMode) -> dict:
|
||||||
|
"""Transforms gymnasium keyword arguments to be envpool-compatible.
|
||||||
|
|
||||||
|
:param kwargs: keyword arguments that would normally be passed to `gymnasium.make`.
|
||||||
|
:param mode: the environment mode
|
||||||
|
:return: the transformed keyword arguments
|
||||||
|
"""
|
||||||
|
kwargs = dict(kwargs)
|
||||||
|
if "render_mode" in kwargs:
|
||||||
|
del kwargs["render_mode"]
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
def create_venv(
|
||||||
|
self,
|
||||||
|
task: str,
|
||||||
|
num_envs: int,
|
||||||
|
mode: EnvMode,
|
||||||
|
seed: int,
|
||||||
|
kwargs: dict,
|
||||||
|
) -> BaseVectorEnv | None:
|
||||||
|
try:
|
||||||
|
import envpool
|
||||||
|
|
||||||
|
envpool_task = self._transform_task(task)
|
||||||
|
envpool_kwargs = self._transform_kwargs(kwargs, mode)
|
||||||
|
return envpool.make_gymnasium(
|
||||||
|
envpool_task,
|
||||||
|
num_envs=num_envs,
|
||||||
|
seed=seed,
|
||||||
|
**envpool_kwargs,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class EnvFactory(ToStringMixin, ABC):
|
class EnvFactory(ToStringMixin, ABC):
|
||||||
|
def __init__(self, venv_type: VectorEnvType):
|
||||||
|
""":param venv_type: the type of vectorized environment to use"""
|
||||||
|
self.venv_type = venv_type
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
|
def create_env(self, mode: EnvMode) -> Env:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv:
|
||||||
|
"""Create vectorized environments.
|
||||||
|
|
||||||
|
:param num_envs: the number of environments
|
||||||
|
:param mode: the mode for which to create
|
||||||
|
:return: the vectorized environments
|
||||||
|
"""
|
||||||
|
return self.venv_type.create_venv([lambda: self.create_env(mode)] * num_envs)
|
||||||
|
|
||||||
|
def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
|
||||||
|
"""Create environments for learning.
|
||||||
|
|
||||||
|
:param num_training_envs: the number of training environments
|
||||||
|
:param num_test_envs: the number of test environments
|
||||||
|
:return: the environments
|
||||||
|
"""
|
||||||
|
env = self.create_env(EnvMode.TRAIN)
|
||||||
|
train_envs = self.create_venv(num_training_envs, EnvMode.TRAIN)
|
||||||
|
test_envs = self.create_venv(num_test_envs, EnvMode.TEST)
|
||||||
|
match EnvType.from_env(env):
|
||||||
|
case EnvType.DISCRETE:
|
||||||
|
return DiscreteEnvironments(env, train_envs, test_envs)
|
||||||
|
case EnvType.CONTINUOUS:
|
||||||
|
return ContinuousEnvironments(env, train_envs, test_envs)
|
||||||
|
case _:
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
|
||||||
|
class EnvFactoryGymnasium(EnvFactory):
|
||||||
|
"""Factory for environments that can be created via `gymnasium.make` (or via `envpool.make_gymnasium`)."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
task: str,
|
||||||
|
seed: int,
|
||||||
|
venv_type: VectorEnvType,
|
||||||
|
envpool_factory: EnvPoolFactory | None = None,
|
||||||
|
render_mode_train: str | None = None,
|
||||||
|
render_mode_test: str | None = None,
|
||||||
|
render_mode_watch: str = "human",
|
||||||
|
**kwargs: Any,
|
||||||
|
):
|
||||||
|
""":param task: the gymnasium task/environment identifier
|
||||||
|
:param seed: the random seed
|
||||||
|
:param venv_type: the type of vectorized environment to use. If `envpool_factory` is specified, this is but a fallback.
|
||||||
|
:param envpool_factory: the factory to use for envpool-based vectorized environment creation if `envpool` is installed.
|
||||||
|
If it is not installed, `venv_type` applies as a fallback.
|
||||||
|
:param render_mode_train: the render mode to use for training environments
|
||||||
|
:param render_mode_test: the render mode to use for test environments
|
||||||
|
:param render_mode_watch: the render mode to use for environments that are used to watch agent performance
|
||||||
|
:param kwargs: additional keyword arguments to pass on to `gymnasium.make`.
|
||||||
|
If envpool is used, the gymnasium parameters will be appropriately translated for use with
|
||||||
|
`envpool.make_gymnasium`.
|
||||||
|
"""
|
||||||
|
super().__init__(venv_type)
|
||||||
|
self.task = task
|
||||||
|
self.envpool_factory = envpool_factory
|
||||||
|
self.seed = seed
|
||||||
|
self.render_modes = {
|
||||||
|
EnvMode.TRAIN: render_mode_train,
|
||||||
|
EnvMode.TEST: render_mode_test,
|
||||||
|
EnvMode.WATCH: render_mode_watch,
|
||||||
|
}
|
||||||
|
self.kwargs = kwargs
|
||||||
|
|
||||||
|
def _create_kwargs(self, mode: EnvMode) -> dict:
|
||||||
|
"""Adapts the keyword arguments for the given mode.
|
||||||
|
|
||||||
|
:param mode: the mode
|
||||||
|
:return: adapted keyword arguments
|
||||||
|
"""
|
||||||
|
kwargs = dict(self.kwargs)
|
||||||
|
kwargs["render_mode"] = self.render_modes.get(mode)
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
def create_env(self, mode: EnvMode) -> Env:
|
||||||
|
"""Creates a single environment for the given mode.
|
||||||
|
|
||||||
|
:param mode: the mode
|
||||||
|
:return: an environment
|
||||||
|
"""
|
||||||
|
kwargs = self._create_kwargs(mode)
|
||||||
|
return gymnasium.make(self.task, **kwargs)
|
||||||
|
|
||||||
|
def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv:
|
||||||
|
if self.envpool_factory is not None:
|
||||||
|
venv = self.envpool_factory.create_venv(
|
||||||
|
self.task,
|
||||||
|
num_envs,
|
||||||
|
mode,
|
||||||
|
self.seed,
|
||||||
|
self._create_kwargs(mode),
|
||||||
|
)
|
||||||
|
if venv is not None:
|
||||||
|
return venv
|
||||||
|
log.debug(
|
||||||
|
f"EnvPool-based creation could not be applied, falling back to default based on {self.venv_type}",
|
||||||
|
)
|
||||||
|
|
||||||
|
venv = super().create_venv(num_envs, mode)
|
||||||
|
venv.seed(self.seed)
|
||||||
|
return venv
|
||||||
|
@ -26,7 +26,7 @@ from tianshou.highlevel.agent import (
|
|||||||
TRPOAgentFactory,
|
TRPOAgentFactory,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.config import SamplingConfig
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
from tianshou.highlevel.env import EnvFactory
|
from tianshou.highlevel.env import EnvFactory, EnvMode
|
||||||
from tianshou.highlevel.logger import LoggerFactory, LoggerFactoryDefault, TLogger
|
from tianshou.highlevel.logger import LoggerFactory, LoggerFactoryDefault, TLogger
|
||||||
from tianshou.highlevel.module.actor import (
|
from tianshou.highlevel.module.actor import (
|
||||||
ActorFactory,
|
ActorFactory,
|
||||||
@ -293,7 +293,7 @@ class Experiment(ToStringMixin):
|
|||||||
self._watch_agent(
|
self._watch_agent(
|
||||||
self.config.watch_num_episodes,
|
self.config.watch_num_episodes,
|
||||||
policy,
|
policy,
|
||||||
test_collector,
|
self.env_factory,
|
||||||
self.config.watch_render,
|
self.config.watch_render,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -303,15 +303,18 @@ class Experiment(ToStringMixin):
|
|||||||
def _watch_agent(
|
def _watch_agent(
|
||||||
num_episodes: int,
|
num_episodes: int,
|
||||||
policy: BasePolicy,
|
policy: BasePolicy,
|
||||||
test_collector: Collector,
|
env_factory: EnvFactory,
|
||||||
render: float,
|
render: float,
|
||||||
) -> None:
|
) -> None:
|
||||||
policy.eval()
|
policy.eval()
|
||||||
test_collector.reset()
|
env = env_factory.create_env(EnvMode.WATCH)
|
||||||
result = test_collector.collect(n_episode=num_episodes, render=render)
|
collector = Collector(policy, env)
|
||||||
|
result = collector.collect(n_episode=num_episodes, render=render)
|
||||||
assert result.returns_stat is not None # for mypy
|
assert result.returns_stat is not None # for mypy
|
||||||
assert result.lens_stat is not None # for mypy
|
assert result.lens_stat is not None # for mypy
|
||||||
print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}")
|
log.info(
|
||||||
|
f"Watched episodes: mean reward={result.returns_stat.mean}, mean episode length={result.lens_stat.mean}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ExperimentBuilder:
|
class ExperimentBuilder:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user