Fix/add watch env with obs rms (#1061)

Adds support for deciding, via the high-level interfaces, whether to watch the agent performing on the environment (a usage sketch follows the commit metadata below).
maxhuettenrauch 2024-02-29 15:59:11 +01:00 committed by GitHub
parent 49781e715e
commit 7c970df53f
3 changed files with 81 additions and 22 deletions
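A minimal usage sketch of the change, assuming an already constructed MujocoEnvFactory instance env_factory (its constructor arguments are not part of this diff and are omitted here):

envs = env_factory.create_envs(
    num_training_envs=4,
    num_test_envs=2,
    create_watch_env=True,  # new flag introduced by this commit
)
assert envs.watch_env is not None
# with obs_norm enabled, the watch env is given the training envs' obs-rms
# statistics, so observations are normalized consistently while watching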


@@ -1,10 +1,11 @@
import logging
import pickle
from tianshou.env import VectorEnvNormObs
from tianshou.env import BaseVectorEnv, VectorEnvNormObs
from tianshou.highlevel.env import (
ContinuousEnvironments,
EnvFactoryRegistered,
EnvMode,
EnvPoolFactory,
VectorEnvType,
)
@@ -56,6 +57,8 @@ class MujocoEnvObsRmsPersistence(Persistence):
obs_rms = pickle.load(f)
world.envs.train_envs.set_obs_rms(obs_rms)
world.envs.test_envs.set_obs_rms(obs_rms)
if world.envs.watch_env is not None:
world.envs.watch_env.set_obs_rms(obs_rms)
class MujocoEnvFactory(EnvFactoryRegistered):
@@ -68,15 +71,31 @@ class MujocoEnvFactory(EnvFactoryRegistered):
)
self.obs_norm = obs_norm
def create_envs(self, num_training_envs: int, num_test_envs: int) -> ContinuousEnvironments:
envs = super().create_envs(num_training_envs, num_test_envs)
assert isinstance(envs, ContinuousEnvironments)
def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv:
"""Create vectorized environments.
:param num_envs: the number of environments
:param mode: the mode for which to create the environments
:return: the vectorized environments
"""
env = super().create_venv(num_envs, mode)
# obs norm wrapper
if self.obs_norm:
envs.train_envs = VectorEnvNormObs(envs.train_envs)
envs.test_envs = VectorEnvNormObs(envs.test_envs, update_obs_rms=False)
envs.test_envs.set_obs_rms(envs.train_envs.get_obs_rms())
envs.set_persistence(MujocoEnvObsRmsPersistence())
env = VectorEnvNormObs(env, update_obs_rms=mode == EnvMode.TRAIN)
return env
def create_envs(
self,
num_training_envs: int,
num_test_envs: int,
create_watch_env: bool = False,
) -> ContinuousEnvironments:
envs = super().create_envs(num_training_envs, num_test_envs, create_watch_env)
assert isinstance(envs, ContinuousEnvironments)
if self.obs_norm:
envs.test_envs.set_obs_rms(envs.train_envs.get_obs_rms())
if envs.watch_env is not None:
envs.watch_env.set_obs_rms(envs.train_envs.get_obs_rms())
envs.set_persistence(MujocoEnvObsRmsPersistence())
return envs
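The normalization wiring above can also be written outside the high-level API; a minimal sketch using names from tianshou.env (the environment id is arbitrary and assumes the MuJoCo extras are installed):

import gymnasium as gym
from tianshou.env import DummyVectorEnv, VectorEnvNormObs

def make_env() -> gym.Env:
    return gym.make("HalfCheetah-v4")

# only the training envs update the running mean/std of observations
train_envs = VectorEnvNormObs(DummyVectorEnv([make_env] * 4))
test_envs = VectorEnvNormObs(DummyVectorEnv([make_env] * 2), update_obs_rms=False)
watch_env = VectorEnvNormObs(DummyVectorEnv([make_env]), update_obs_rms=False)

# test and watch envs reuse the statistics collected on the training envs
test_envs.set_obs_rms(train_envs.get_obs_rms())
watch_env.set_obs_rms(train_envs.get_obs_rms())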


@@ -89,10 +89,17 @@ class VectorEnvType(Enum):
class Environments(ToStringMixin, ABC):
"""Represents (vectorized) environments for a learning process."""
def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
def __init__(
self,
env: gym.Env,
train_envs: BaseVectorEnv,
test_envs: BaseVectorEnv,
watch_env: BaseVectorEnv | None = None,
):
self.env = env
self.train_envs = train_envs
self.test_envs = test_envs
self.watch_env = watch_env
self.persistence: Sequence[Persistence] = []
@staticmethod
@@ -102,6 +109,7 @@ class Environments(ToStringMixin, ABC):
venv_type: VectorEnvType,
num_training_envs: int,
num_test_envs: int,
create_watch_env: bool = False,
) -> "Environments":
"""Creates a suitable subtype instance from a factory function that creates a single instance and the type of environment (continuous/discrete).
@@ -110,16 +118,21 @@ class Environments(ToStringMixin, ABC):
:param venv_type: the vector environment type to use for parallelization
:param num_training_envs: the number of training environments to create
:param num_test_envs: the number of test environments to create
:param create_watch_env: whether to create an environment for watching the agent
:return: the instance
"""
train_envs = venv_type.create_venv([lambda: factory_fn(EnvMode.TRAIN)] * num_training_envs)
test_envs = venv_type.create_venv([lambda: factory_fn(EnvMode.TEST)] * num_test_envs)
if create_watch_env:
watch_env = venv_type.create_venv([lambda: factory_fn(EnvMode.WATCH)])
else:
watch_env = None
env = factory_fn(EnvMode.TRAIN)
match env_type:
case EnvType.CONTINUOUS:
return ContinuousEnvironments(env, train_envs, test_envs)
return ContinuousEnvironments(env, train_envs, test_envs, watch_env)
case EnvType.DISCRETE:
return DiscreteEnvironments(env, train_envs, test_envs)
return DiscreteEnvironments(env, train_envs, test_envs, watch_env)
case _:
raise ValueError(f"Environment type {env_type} not handled")
@@ -164,8 +177,14 @@ class Environments(ToStringMixin, ABC):
class ContinuousEnvironments(Environments):
"""Represents (vectorized) continuous environments."""
def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
super().__init__(env, train_envs, test_envs)
def __init__(
self,
env: gym.Env,
train_envs: BaseVectorEnv,
test_envs: BaseVectorEnv,
watch_env: BaseVectorEnv | None = None,
):
super().__init__(env, train_envs, test_envs, watch_env)
self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env)
@staticmethod
@@ -174,6 +193,7 @@ class ContinuousEnvironments(Environments):
venv_type: VectorEnvType,
num_training_envs: int,
num_test_envs: int,
create_watch_env: bool = False,
) -> "ContinuousEnvironments":
"""Creates an instance from a factory function that creates a single instance.
@@ -181,6 +201,7 @@ class ContinuousEnvironments(Environments):
:param venv_type: the vector environment type to use for parallelization
:param num_training_envs: the number of training environments to create
:param num_test_envs: the number of test environments to create
:param create_watch_env: whether to create an environment for watching the agent
:return: the instance
"""
return cast(
@@ -191,6 +212,7 @@ class ContinuousEnvironments(Environments):
venv_type,
num_training_envs,
num_test_envs,
create_watch_env,
),
)
@@ -228,8 +250,14 @@ class ContinuousEnvironments(Environments):
class DiscreteEnvironments(Environments):
"""Represents (vectorized) discrete environments."""
def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
super().__init__(env, train_envs, test_envs)
def __init__(
self,
env: gym.Env,
train_envs: BaseVectorEnv,
test_envs: BaseVectorEnv,
watch_env: BaseVectorEnv | None = None,
):
super().__init__(env, train_envs, test_envs, watch_env)
self.observation_shape = env.observation_space.shape or env.observation_space.n # type: ignore
self.action_shape = env.action_space.shape or env.action_space.n # type: ignore
@@ -239,6 +267,7 @@ class DiscreteEnvironments(Environments):
venv_type: VectorEnvType,
num_training_envs: int,
num_test_envs: int,
create_watch_env: bool = False,
) -> "DiscreteEnvironments":
"""Creates an instance from a factory function that creates a single instance.
@@ -246,6 +275,7 @@ class DiscreteEnvironments(Environments):
:param venv_type: the vector environment type to use for parallelization
:param num_training_envs: the number of training environments to create
:param num_test_envs: the number of test environments to create
:param create_watch_env: whether to create an environment for watching the agent
:return: the instance
"""
return cast(
@@ -256,6 +286,7 @@ class DiscreteEnvironments(Environments):
venv_type,
num_training_envs,
num_test_envs,
create_watch_env,
),
)
@@ -329,21 +360,28 @@ class EnvFactory(ToStringMixin, ABC):
"""
return self.venv_type.create_venv([lambda: self.create_env(mode)] * num_envs)
def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
def create_envs(
self,
num_training_envs: int,
num_test_envs: int,
create_watch_env: bool = False,
) -> Environments:
"""Create environments for learning.
:param num_training_envs: the number of training environments
:param num_test_envs: the number of test environments
:param create_watch_env: whether to create an environment for watching the agent
:return: the environments
"""
env = self.create_env(EnvMode.TRAIN)
train_envs = self.create_venv(num_training_envs, EnvMode.TRAIN)
test_envs = self.create_venv(num_test_envs, EnvMode.TEST)
watch_env = self.create_venv(1, EnvMode.WATCH) if create_watch_env else None
match EnvType.from_env(env):
case EnvType.DISCRETE:
return DiscreteEnvironments(env, train_envs, test_envs)
return DiscreteEnvironments(env, train_envs, test_envs, watch_env)
case EnvType.CONTINUOUS:
return ContinuousEnvironments(env, train_envs, test_envs)
return ContinuousEnvironments(env, train_envs, test_envs, watch_env)
case _:
raise ValueError
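To see how these hooks come together for a custom environment, a hedged sketch of an EnvFactory subclass (the EnvFactory constructor signature and the VectorEnvType.DUMMY member are assumptions, not shown in this diff):

import gymnasium as gym
from tianshou.highlevel.env import EnvFactory, EnvMode, VectorEnvType

class PendulumEnvFactory(EnvFactory):
    def __init__(self) -> None:
        super().__init__(VectorEnvType.DUMMY)  # assumed constructor signature

    def create_env(self, mode: EnvMode) -> gym.Env:
        # render only when the env is meant for watching the agent
        render_mode = "human" if mode == EnvMode.WATCH else None
        return gym.make("Pendulum-v1", render_mode=render_mode)

envs = PendulumEnvFactory().create_envs(
    num_training_envs=2,
    num_test_envs=1,
    create_watch_env=True,
)
assert envs.watch_env is not None  # a single rendering env for watching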


@@ -10,6 +10,7 @@ import numpy as np
import torch
from tianshou.data import Collector, InfoStats
from tianshou.env import BaseVectorEnv
from tianshou.highlevel.agent import (
A2CAgentFactory,
AgentFactory,
@@ -26,7 +27,7 @@ from tianshou.highlevel.agent import (
TRPOAgentFactory,
)
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import EnvFactory, EnvMode
from tianshou.highlevel.env import EnvFactory
from tianshou.highlevel.logger import LoggerFactory, LoggerFactoryDefault, TLogger
from tianshou.highlevel.module.actor import (
ActorFactory,
@@ -222,6 +223,7 @@ class Experiment(ToStringMixin):
envs = self.env_factory.create_envs(
self.sampling_config.num_train_envs,
self.sampling_config.num_test_envs,
create_watch_env=self.config.watch,
)
log.info(f"Created {envs}")
@@ -289,11 +291,12 @@ class Experiment(ToStringMixin):
# watch agent performance
if self.config.watch:
assert envs.watch_env is not None
log.info("Watching agent performance")
self._watch_agent(
self.config.watch_num_episodes,
policy,
self.env_factory,
envs.watch_env,
self.config.watch_render,
)
@@ -303,11 +306,10 @@ class Experiment(ToStringMixin):
def _watch_agent(
num_episodes: int,
policy: BasePolicy,
env_factory: EnvFactory,
env: BaseVectorEnv,
render: float,
) -> None:
policy.eval()
env = env_factory.create_env(EnvMode.WATCH)
collector = Collector(policy, env)
result = collector.collect(n_episode=num_episodes, render=render)
assert result.returns_stat is not None # for mypy
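For completeness, a sketch of the resulting watch flow: the pre-created watch env is passed straight to a Collector rather than being built from the factory on the fly (policy and envs are assumed to come from a completed run):

from tianshou.data import Collector
from tianshou.highlevel.env import Environments
from tianshou.policy import BasePolicy

def watch(policy: BasePolicy, envs: Environments, num_episodes: int = 10, render: float = 0.0) -> None:
    assert envs.watch_env is not None
    policy.eval()
    collector = Collector(policy, envs.watch_env)
    result = collector.collect(n_episode=num_episodes, render=render)
    assert result.returns_stat is not None
    print(result.returns_stat)  # aggregate returns of the watched episodes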