Improve environment factory abstractions in high-level API:

* EnvFactory now uses the creation of a single environment as the
  basic operation upon which the more high-level functions build
* Introduce enum EnvMode to indicate the purpose for which an env
  is created, allowing the factory creation process to change its
  behaviour accordingly
* Add EnvFactoryGymnasium to provide direct support for envs that
  can be created via gymnasium.make
    - EnvPool is supported via an injectable EnvPoolFactory
    - Existing EnvFactory implementations are now derived from
      EnvFactoryGymnasium
* Use a separate environment (created with the new EnvMode.WATCH)
  for watching agent performance after training, instead of using
  test environments, which the user may want to configure differently
Dominik Jain 2024-01-10 15:37:58 +01:00
parent 8188a904af
commit eaab7b0a4b
5 changed files with 285 additions and 104 deletions
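
To illustrate the new abstraction (a sketch, not part of the commit; the subclass and task id are made up):

    import gymnasium as gym

    from tianshou.highlevel.env import EnvFactory, EnvMode, VectorEnvType

    class MyEnvFactory(EnvFactory):
        """Hypothetical factory: creating a single env is now the basic operation."""

        def create_env(self, mode: EnvMode) -> gym.Env:
            # behaviour may vary with the purpose, e.g. render only when watching
            render_mode = "human" if mode == EnvMode.WATCH else None
            return gym.make("CartPole-v1", render_mode=render_mode)

    factory = MyEnvFactory(venv_type=VectorEnvType.DUMMY)
    envs = factory.create_envs(num_training_envs=4, num_test_envs=2)  # builds on create_env
    watch_env = factory.create_env(EnvMode.WATCH)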


@@ -7,9 +7,15 @@ from collections import deque
import cv2
import gymnasium as gym
import numpy as np
from gymnasium import Env

from tianshou.env import ShmemVectorEnv
from tianshou.highlevel.env import DiscreteEnvironments, EnvFactory
from tianshou.highlevel.env import (
    EnvFactoryGymnasium,
    EnvMode,
    EnvPoolFactory,
    VectorEnvType,
)
from tianshou.highlevel.trainer import TrainerStopCallback, TrainingContext

try:
@@ -282,7 +288,7 @@ class FrameStack(gym.Wrapper):
def wrap_deepmind(
    env_id,
    env: Env,
    episode_life=True,
    clip_rewards=True,
    frame_stack=4,
@@ -293,7 +299,7 @@ def wrap_deepmind(
    The observation is channel-first: (c, h, w) instead of (h, w, c).

    :param str env_id: the atari environment id.
    :param env: the Atari environment to wrap.
    :param bool episode_life: wrap the episode life wrapper.
    :param bool clip_rewards: wrap the reward clipping wrapper.
    :param int frame_stack: wrap the frame stacking wrapper.
@@ -301,8 +307,6 @@ def wrap_deepmind(
    :param bool warp_frame: wrap the grayscale + resize observation wrapper.
    :return: the wrapped atari environment.
    """
    assert "NoFrameskip" in env_id
    env = gym.make(env_id)
    env = NoopResetEnv(env, noop_max=30)
    env = MaxAndSkipEnv(env, skip=4)
    if episode_life:
@@ -351,19 +355,30 @@ def make_atari_env(task, seed, training_num, test_num, **kwargs):
            stack_num=kwargs.get("frame_stack", 4),
        )
    else:
        assert "NoFrameskip" in task
        warnings.warn(
            "Recommend using envpool (pip install envpool) to run Atari games more efficiently.",
        )
        env = wrap_deepmind(task, **kwargs)
        train_envs = ShmemVectorEnv(
            [
                lambda: wrap_deepmind(task, episode_life=True, clip_rewards=True, **kwargs)
                lambda: wrap_deepmind(
                    gym.make(task),
                    episode_life=True,
                    clip_rewards=True,
                    **kwargs,
                )
                for _ in range(training_num)
            ],
        )
        test_envs = ShmemVectorEnv(
            [
                lambda: wrap_deepmind(task, episode_life=False, clip_rewards=False, **kwargs)
                lambda: wrap_deepmind(
                    gym.make(task),
                    episode_life=False,
                    clip_rewards=False,
                    **kwargs,
                )
                for _ in range(test_num)
            ],
        )
@@ -373,23 +388,49 @@ def make_atari_env(task, seed, training_num, test_num, **kwargs):
    return env, train_envs, test_envs

class AtariEnvFactory(EnvFactory):
    def __init__(self, task: str, seed: int, frame_stack: int, scale: int = 0):
        self.task = task
        self.seed = seed
class AtariEnvFactory(EnvFactoryGymnasium):
    def __init__(self, task: str, seed: int, frame_stack: int, scale: bool = False):
        assert "NoFrameskip" in task
        self.frame_stack = frame_stack
        self.scale = scale
    def create_envs(self, num_training_envs: int, num_test_envs: int) -> DiscreteEnvironments:
        env, train_envs, test_envs = make_atari_env(
            task=self.task,
            seed=self.seed,
            training_num=num_training_envs,
            test_num=num_test_envs,
            scale=self.scale,
            frame_stack=self.frame_stack,
        super().__init__(
            task=task,
            seed=seed,
            venv_type=VectorEnvType.SUBPROC_SHARED_MEM,
            envpool_factory=self.EnvPoolFactory(self),
        )
        return DiscreteEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)

    def create_env(self, mode: EnvMode) -> Env:
        env = super().create_env(mode)
        is_train = mode == EnvMode.TRAIN
        return wrap_deepmind(
            env,
            episode_life=is_train,
            clip_rewards=is_train,
            frame_stack=self.frame_stack,
            scale=self.scale,
        )

    class EnvPoolFactory(EnvPoolFactory):
        def __init__(self, parent: "AtariEnvFactory"):
            self.parent = parent
            if self.parent.scale:
                warnings.warn(
                    "EnvPool does not include ScaledFloatFrame wrapper, "
                    "please compensate by scaling inside your network's forward function (e.g. `x = x / 255.0` for Atari)",
                )

        def _transform_task(self, task: str) -> str:
            task = super()._transform_task(task)
            return task.replace("NoFrameskip-v4", "-v5")

        def _transform_kwargs(self, kwargs: dict, mode: EnvMode) -> dict:
            kwargs = super()._transform_kwargs(kwargs, mode)
            is_train = mode == EnvMode.TRAIN
            kwargs["reward_clip"] = is_train
            kwargs["episodic_life"] = is_train
            kwargs["stack_num"] = self.parent.frame_stack
            return kwargs

class AtariStopCallback(TrainerStopCallback):
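
Usage of the revised factory might look like this (an illustrative sketch, not part of the commit; the task id and environment counts are made up):

    from tianshou.highlevel.env import EnvMode

    factory = AtariEnvFactory("PongNoFrameskip-v4", seed=0, frame_stack=4)
    envs = factory.create_envs(num_training_envs=8, num_test_envs=4)  # uses envpool if installed
    watch_env = factory.create_env(EnvMode.WATCH)  # DeepMind-wrapped, with non-training settings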


@@ -1,11 +1,13 @@
import logging
import pickle
import warnings

import gymnasium as gym

from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
from tianshou.env import VectorEnvNormObs
from tianshou.highlevel.env import (
    ContinuousEnvironments,
    EnvFactoryGymnasium,
    EnvPoolFactory,
    VectorEnvType,
)
from tianshou.highlevel.persistence import Persistence, PersistEvent, RestoreEvent
from tianshou.highlevel.world import World
@@ -24,25 +26,11 @@ def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: in
    :return: a tuple of (single env, training envs, test envs).
    """
    if envpool is not None:
        train_envs = env = envpool.make_gymnasium(task, num_envs=num_train_envs, seed=seed)
        test_envs = envpool.make_gymnasium(task, num_envs=num_test_envs, seed=seed)
    else:
        warnings.warn(
            "Recommend using envpool (pip install envpool) "
            "to run Mujoco environments more efficiently.",
        )
        env = gym.make(task)
        train_envs = ShmemVectorEnv([lambda: gym.make(task) for _ in range(num_train_envs)])
        test_envs = ShmemVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)])
    train_envs.seed(seed)
    test_envs.seed(seed)
    if obs_norm:
        # obs norm wrapper
        train_envs = VectorEnvNormObs(train_envs)
        test_envs = VectorEnvNormObs(test_envs, update_obs_rms=False)
        test_envs.set_obs_rms(train_envs.get_obs_rms())
    return env, train_envs, test_envs
    envs = MujocoEnvFactory(task, seed, obs_norm=obs_norm).create_envs(
        num_train_envs,
        num_test_envs,
    )
    return envs.env, envs.train_envs, envs.test_envs
class MujocoEnvObsRmsPersistence(Persistence):
@@ -68,21 +56,25 @@ class MujocoEnvObsRmsPersistence(Persistence):
        world.envs.test_envs.set_obs_rms(obs_rms)

class MujocoEnvFactory(EnvFactory):
class MujocoEnvFactory(EnvFactoryGymnasium):
    def __init__(self, task: str, seed: int, obs_norm=True):
        self.task = task
        self.seed = seed
        super().__init__(
            task=task,
            seed=seed,
            venv_type=VectorEnvType.SUBPROC_SHARED_MEM,
            envpool_factory=EnvPoolFactory(),
        )
        self.obs_norm = obs_norm

    def create_envs(self, num_training_envs: int, num_test_envs: int) -> ContinuousEnvironments:
        env, train_envs, test_envs = make_mujoco_env(
            task=self.task,
            seed=self.seed,
            num_train_envs=num_training_envs,
            num_test_envs=num_test_envs,
            obs_norm=self.obs_norm,
        )
        envs = ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)
        envs = super().create_envs(num_training_envs, num_test_envs)
        assert isinstance(envs, ContinuousEnvironments)
        # obs norm wrapper
        if self.obs_norm:
            envs.train_envs = VectorEnvNormObs(envs.train_envs)
            envs.test_envs = VectorEnvNormObs(envs.test_envs, update_obs_rms=False)
            envs.test_envs.set_obs_rms(envs.train_envs.get_obs_rms())
        envs.set_persistence(MujocoEnvObsRmsPersistence())
        return envs
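
For example (a sketch; the task id is illustrative), normalized environments are now obtained via the factory:

    factory = MujocoEnvFactory("HalfCheetah-v4", seed=42, obs_norm=True)
    envs = factory.create_envs(4, 2)
    # envs.train_envs and envs.test_envs are wrapped in VectorEnvNormObs,
    # with the test envs sharing the training envs' running obs statistics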


@@ -1,27 +1,14 @@
import gymnasium as gym

from tianshou.env import DummyVectorEnv
from tianshou.highlevel.env import (
    ContinuousEnvironments,
    DiscreteEnvironments,
    EnvFactory,
    Environments,
    EnvFactoryGymnasium,
    VectorEnvType,
)

class DiscreteTestEnvFactory(EnvFactory):
    def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
        task = "CartPole-v0"
        env = gym.make(task)
        train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)])
        test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)])
        return DiscreteEnvironments(env, train_envs, test_envs)
class DiscreteTestEnvFactory(EnvFactoryGymnasium):
    def __init__(self):
        super().__init__(task="CartPole-v0", seed=42, venv_type=VectorEnvType.DUMMY)

class ContinuousTestEnvFactory(EnvFactory):
    def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
        task = "Pendulum-v1"
        env = gym.make(task)
        train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)])
        test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)])
        return ContinuousEnvironments(env, train_envs, test_envs)
class ContinuousTestEnvFactory(EnvFactoryGymnasium):
    def __init__(self):
        super().__init__(task="Pendulum-v1", seed=42, venv_type=VectorEnvType.DUMMY)
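
The simplified test factories are then used as follows (a sketch; the counts are illustrative):

    factory = DiscreteTestEnvFactory()  # CartPole-v0 via DummyVectorEnv, seed 42
    envs = factory.create_envs(num_training_envs=2, num_test_envs=2)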


@@ -1,9 +1,12 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
from enum import Enum
from typing import Any, TypeAlias, cast

import gymnasium as gym
import gymnasium.spaces
from gymnasium import Env

from tianshou.env import (
    BaseVectorEnv,
@@ -18,6 +21,8 @@ from tianshou.utils.string import ToStringMixin

TObservationShape: TypeAlias = int | Sequence[int]

log = logging.getLogger(__name__)

class EnvType(Enum):
    """Enumeration of environment types."""
@@ -39,6 +44,23 @@ class EnvType(Enum):
        if not self.is_discrete():
            raise AssertionError(f"{requiring_entity} requires discrete environments")

    @staticmethod
    def from_env(env: Env) -> "EnvType":
        if isinstance(env.action_space, gymnasium.spaces.Discrete):
            return EnvType.DISCRETE
        elif isinstance(env.action_space, gymnasium.spaces.Box):
            return EnvType.CONTINUOUS
        else:
            raise Exception(f"Unsupported environment type with action space {env.action_space}")

class EnvMode(Enum):
    """Indicates the purpose for which an environment is created."""

    TRAIN = "train"
    TEST = "test"
    WATCH = "watch"
class VectorEnvType(Enum):
    DUMMY = "dummy"
@@ -65,7 +87,7 @@
class Environments(ToStringMixin, ABC):
    """Represents (vectorized) environments."""
    """Represents (vectorized) environments for a learning process."""

    def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
        self.env = env
@@ -75,12 +97,11 @@ class Environments(ToStringMixin, ABC):
    @staticmethod
    def from_factory_and_type(
        factory_fn: Callable[[], gym.Env],
        factory_fn: Callable[[EnvMode], gym.Env],
        env_type: EnvType,
        venv_type: VectorEnvType,
        num_training_envs: int,
        num_test_envs: int,
        test_factory_fn: Callable[[], gym.Env] | None = None,
    ) -> "Environments":
        """Creates a suitable subtype instance from a factory function that creates a single instance and the type of environment (continuous/discrete).
@@ -89,15 +110,11 @@ class Environments(ToStringMixin, ABC):
        :param venv_type: the vector environment type to use for parallelization
        :param num_training_envs: the number of training environments to create
        :param num_test_envs: the number of test environments to create
        :param test_factory_fn: the factory to use for the creation of test environment instances;
            if None, use `factory_fn` for all environments (train and test)
        :return: the instance
        """
        if test_factory_fn is None:
            test_factory_fn = factory_fn
        train_envs = venv_type.create_venv([factory_fn] * num_training_envs)
        test_envs = venv_type.create_venv([test_factory_fn] * num_test_envs)
        env = factory_fn()
        train_envs = venv_type.create_venv([lambda: factory_fn(EnvMode.TRAIN)] * num_training_envs)
        test_envs = venv_type.create_venv([lambda: factory_fn(EnvMode.TEST)] * num_test_envs)
        env = factory_fn(EnvMode.TRAIN)
        match env_type:
            case EnvType.CONTINUOUS:
                return ContinuousEnvironments(env, train_envs, test_envs)
@@ -153,11 +170,10 @@ class ContinuousEnvironments(Environments):
    @staticmethod
    def from_factory(
        factory_fn: Callable[[], gym.Env],
        factory_fn: Callable[[EnvMode], gym.Env],
        venv_type: VectorEnvType,
        num_training_envs: int,
        num_test_envs: int,
        test_factory_fn: Callable[[], gym.Env] | None = None,
    ) -> "ContinuousEnvironments":
        """Creates an instance from a factory function that creates a single instance.
@@ -165,8 +181,6 @@ class ContinuousEnvironments(Environments):
        :param venv_type: the vector environment type to use for parallelization
        :param num_training_envs: the number of training environments to create
        :param num_test_envs: the number of test environments to create
        :param test_factory_fn: the factory to use for the creation of test environment instances;
            if None, use `factory_fn` for all environments (train and test)
        :return: the instance
        """
        return cast(
@@ -177,7 +191,6 @@ class ContinuousEnvironments(Environments):
                venv_type,
                num_training_envs,
                num_test_envs,
                test_factory_fn=test_factory_fn,
            ),
        )
@@ -222,11 +235,10 @@ class DiscreteEnvironments(Environments):
    @staticmethod
    def from_factory(
        factory_fn: Callable[[], gym.Env],
        factory_fn: Callable[[EnvMode], gym.Env],
        venv_type: VectorEnvType,
        num_training_envs: int,
        num_test_envs: int,
        test_factory_fn: Callable[[], gym.Env] | None = None,
    ) -> "DiscreteEnvironments":
        """Creates an instance from a factory function that creates a single instance.
@@ -234,8 +246,6 @@ class DiscreteEnvironments(Environments):
        :param venv_type: the vector environment type to use for parallelization
        :param num_training_envs: the number of training environments to create
        :param num_test_envs: the number of test environments to create
        :param test_factory_fn: the factory to use for the creation of test environment instances;
            if None, use `factory_fn` for all environments (train and test)
        :return: the instance
        """
        return cast(
@@ -246,7 +256,6 @@ class DiscreteEnvironments(Environments):
                venv_type,
                num_training_envs,
                num_test_envs,
                test_factory_fn=test_factory_fn,
            ),
        )
@@ -260,7 +269,156 @@ class DiscreteEnvironments(Environments):
        return EnvType.DISCRETE

class EnvPoolFactory:
    """A factory for the creation of envpool-based vectorized environments."""

    def _transform_task(self, task: str) -> str:
        return task

    def _transform_kwargs(self, kwargs: dict, mode: EnvMode) -> dict:
        """Transforms gymnasium keyword arguments to be envpool-compatible.

        :param kwargs: keyword arguments that would normally be passed to `gymnasium.make`
        :param mode: the environment mode
        :return: the transformed keyword arguments
        """
        kwargs = dict(kwargs)
        if "render_mode" in kwargs:
            del kwargs["render_mode"]
        return kwargs

    def create_venv(
        self,
        task: str,
        num_envs: int,
        mode: EnvMode,
        seed: int,
        kwargs: dict,
    ) -> BaseVectorEnv | None:
        try:
            import envpool

            envpool_task = self._transform_task(task)
            envpool_kwargs = self._transform_kwargs(kwargs, mode)
            return envpool.make_gymnasium(
                envpool_task,
                num_envs=num_envs,
                seed=seed,
                **envpool_kwargs,
            )
        except ImportError:
            return None

class EnvFactory(ToStringMixin, ABC):
    def __init__(self, venv_type: VectorEnvType):
        """:param venv_type: the type of vectorized environment to use"""
        self.venv_type = venv_type

    @abstractmethod
    def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
    def create_env(self, mode: EnvMode) -> Env:
        pass

    def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv:
        """Create vectorized environments.

        :param num_envs: the number of environments
        :param mode: the mode for which to create the environments
        :return: the vectorized environments
        """
        return self.venv_type.create_venv([lambda: self.create_env(mode)] * num_envs)

    def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
        """Create environments for learning.

        :param num_training_envs: the number of training environments
        :param num_test_envs: the number of test environments
        :return: the environments
        """
        env = self.create_env(EnvMode.TRAIN)
        train_envs = self.create_venv(num_training_envs, EnvMode.TRAIN)
        test_envs = self.create_venv(num_test_envs, EnvMode.TEST)
        match EnvType.from_env(env):
            case EnvType.DISCRETE:
                return DiscreteEnvironments(env, train_envs, test_envs)
            case EnvType.CONTINUOUS:
                return ContinuousEnvironments(env, train_envs, test_envs)
            case _:
                raise ValueError
class EnvFactoryGymnasium(EnvFactory):
    """Factory for environments that can be created via `gymnasium.make` (or via `envpool.make_gymnasium`)."""

    def __init__(
        self,
        *,
        task: str,
        seed: int,
        venv_type: VectorEnvType,
        envpool_factory: EnvPoolFactory | None = None,
        render_mode_train: str | None = None,
        render_mode_test: str | None = None,
        render_mode_watch: str = "human",
        **kwargs: Any,
    ):
        """:param task: the gymnasium task/environment identifier
        :param seed: the random seed
        :param venv_type: the type of vectorized environment to use; if `envpool_factory` is specified,
            this serves only as a fallback
        :param envpool_factory: the factory to use for envpool-based vectorized environment creation
            if `envpool` is installed; if it is not installed, `venv_type` applies as a fallback
        :param render_mode_train: the render mode to use for training environments
        :param render_mode_test: the render mode to use for test environments
        :param render_mode_watch: the render mode to use for environments that are used to watch agent performance
        :param kwargs: additional keyword arguments to pass on to `gymnasium.make`;
            if envpool is used, the gymnasium parameters will be appropriately translated for use with
            `envpool.make_gymnasium`
        """
        super().__init__(venv_type)
        self.task = task
        self.envpool_factory = envpool_factory
        self.seed = seed
        self.render_modes = {
            EnvMode.TRAIN: render_mode_train,
            EnvMode.TEST: render_mode_test,
            EnvMode.WATCH: render_mode_watch,
        }
        self.kwargs = kwargs

    def _create_kwargs(self, mode: EnvMode) -> dict:
        """Adapts the keyword arguments for the given mode.

        :param mode: the mode
        :return: the adapted keyword arguments
        """
        kwargs = dict(self.kwargs)
        kwargs["render_mode"] = self.render_modes.get(mode)
        return kwargs

    def create_env(self, mode: EnvMode) -> Env:
        """Creates a single environment for the given mode.

        :param mode: the mode
        :return: an environment
        """
        kwargs = self._create_kwargs(mode)
        return gymnasium.make(self.task, **kwargs)

    def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv:
        if self.envpool_factory is not None:
            venv = self.envpool_factory.create_venv(
                self.task,
                num_envs,
                mode,
                self.seed,
                self._create_kwargs(mode),
            )
            if venv is not None:
                return venv
            log.debug(
                f"EnvPool-based creation could not be applied, falling back to default based on {self.venv_type}",
            )
        venv = super().create_venv(num_envs, mode)
        venv.seed(self.seed)
        return venv
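
EnvFactoryGymnasium can also be used directly; a sketch (illustrative task id and counts) relying on the envpool fallback behaviour defined above:

    from tianshou.highlevel.env import (
        EnvFactoryGymnasium,
        EnvMode,
        EnvPoolFactory,
        VectorEnvType,
    )

    factory = EnvFactoryGymnasium(
        task="Pendulum-v1",
        seed=0,
        venv_type=VectorEnvType.DUMMY,  # used only if envpool is unavailable
        envpool_factory=EnvPoolFactory(),
    )
    train_venv = factory.create_venv(4, EnvMode.TRAIN)  # envpool-based if installed
    watch_env = factory.create_env(EnvMode.WATCH)  # single env with render_mode="human"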


@@ -26,7 +26,7 @@ from tianshou.highlevel.agent import (
    TRPOAgentFactory,
)
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import EnvFactory
from tianshou.highlevel.env import EnvFactory, EnvMode
from tianshou.highlevel.logger import LoggerFactory, LoggerFactoryDefault, TLogger
from tianshou.highlevel.module.actor import (
    ActorFactory,
@@ -293,7 +293,7 @@ class Experiment(ToStringMixin):
            self._watch_agent(
                self.config.watch_num_episodes,
                policy,
                test_collector,
                self.env_factory,
                self.config.watch_render,
            )
@@ -303,15 +303,18 @@ class Experiment(ToStringMixin):
    def _watch_agent(
        num_episodes: int,
        policy: BasePolicy,
        test_collector: Collector,
        env_factory: EnvFactory,
        render: float,
    ) -> None:
        policy.eval()
        test_collector.reset()
        result = test_collector.collect(n_episode=num_episodes, render=render)
        env = env_factory.create_env(EnvMode.WATCH)
        collector = Collector(policy, env)
        result = collector.collect(n_episode=num_episodes, render=render)
        assert result.returns_stat is not None  # for mypy
        assert result.lens_stat is not None  # for mypy
        print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}")
        log.info(
            f"Watched episodes: mean reward={result.returns_stat.mean}, mean episode length={result.lens_stat.mean}",
        )

class ExperimentBuilder:
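
Because watching now uses a dedicated environment created with EnvMode.WATCH rather than the test environments, its rendering can be configured independently, e.g. (a sketch; parameters are illustrative):

    factory = EnvFactoryGymnasium(
        task="CartPole-v1",
        seed=0,
        venv_type=VectorEnvType.DUMMY,
        render_mode_watch="rgb_array",  # e.g. to capture frames instead of opening a window
    )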