Fix/add watch env with obs rms (#1061)

Adds support for deciding, via the high-level interfaces, whether to watch the agent perform in the environment. A dedicated watch env is now created alongside the train/test envs, and it shares the observation normalization statistics (obs_rms) of the training envs.
maxhuettenrauch 2024-02-29 15:59:11 +01:00 committed by GitHub
parent 49781e715e
commit 7c970df53f
3 changed files with 81 additions and 22 deletions
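
For context, a minimal usage sketch of the new option. The constructor arguments for MujocoEnvFactory follow the Tianshou mujoco examples and may differ in your version; everything else is taken from this diff:

from examples.mujoco.mujoco_env import MujocoEnvFactory

# Hypothetical factory setup, following the Tianshou mujoco examples.
env_factory = MujocoEnvFactory(task="HalfCheetah-v4", seed=42, obs_norm=True)

# New in this commit: create_watch_env adds a single-env vectorized
# environment created in EnvMode.WATCH alongside the train/test envs.
envs = env_factory.create_envs(
    num_training_envs=4,
    num_test_envs=2,
    create_watch_env=True,
)
assert envs.watch_env is not None  # shares the training envs' obs_rms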

examples/mujoco/mujoco_env.py

@@ -1,10 +1,11 @@
 import logging
 import pickle

-from tianshou.env import VectorEnvNormObs
+from tianshou.env import BaseVectorEnv, VectorEnvNormObs
 from tianshou.highlevel.env import (
     ContinuousEnvironments,
     EnvFactoryRegistered,
+    EnvMode,
     EnvPoolFactory,
     VectorEnvType,
 )
@@ -56,6 +57,8 @@ class MujocoEnvObsRmsPersistence(Persistence):
             obs_rms = pickle.load(f)
         world.envs.train_envs.set_obs_rms(obs_rms)
         world.envs.test_envs.set_obs_rms(obs_rms)
+        if world.envs.watch_env is not None:
+            world.envs.watch_env.set_obs_rms(obs_rms)


 class MujocoEnvFactory(EnvFactoryRegistered):
@@ -68,15 +71,31 @@ class MujocoEnvFactory(EnvFactoryRegistered):
         )
         self.obs_norm = obs_norm

-    def create_envs(self, num_training_envs: int, num_test_envs: int) -> ContinuousEnvironments:
-        envs = super().create_envs(num_training_envs, num_test_envs)
-        assert isinstance(envs, ContinuousEnvironments)
+    def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv:
+        """Create vectorized environments.
+
+        :param num_envs: the number of environments
+        :param mode: the mode for which to create
+        :return: the vectorized environments
+        """
+        env = super().create_venv(num_envs, mode)
+
         # obs norm wrapper
         if self.obs_norm:
-            envs.train_envs = VectorEnvNormObs(envs.train_envs)
-            envs.test_envs = VectorEnvNormObs(envs.test_envs, update_obs_rms=False)
-            envs.test_envs.set_obs_rms(envs.train_envs.get_obs_rms())
-            envs.set_persistence(MujocoEnvObsRmsPersistence())
+            env = VectorEnvNormObs(env, update_obs_rms=mode == EnvMode.TRAIN)
+        return env
+
+    def create_envs(
+        self,
+        num_training_envs: int,
+        num_test_envs: int,
+        create_watch_env: bool = False,
+    ) -> ContinuousEnvironments:
+        envs = super().create_envs(num_training_envs, num_test_envs, create_watch_env)
+        assert isinstance(envs, ContinuousEnvironments)
+        if self.obs_norm:
+            envs.test_envs.set_obs_rms(envs.train_envs.get_obs_rms())
+            if envs.watch_env is not None:
+                envs.watch_env.set_obs_rms(envs.train_envs.get_obs_rms())
+            envs.set_persistence(MujocoEnvObsRmsPersistence())
         return envs
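
The per-mode normalization pattern the factory now applies, as a standalone sketch. DummyVectorEnv and Pendulum-v1 are stand-ins for illustration; the VectorEnvNormObs calls match the diff above:

import gymnasium as gym

from tianshou.env import DummyVectorEnv, VectorEnvNormObs

make_env = lambda: gym.make("Pendulum-v1")

# TRAIN mode: the wrapper keeps updating its running obs statistics.
train_envs = VectorEnvNormObs(DummyVectorEnv([make_env] * 4), update_obs_rms=True)

# TEST/WATCH mode: statistics are frozen and copied over from training,
# so evaluation sees the same normalization the policy was trained with.
watch_env = VectorEnvNormObs(DummyVectorEnv([make_env]), update_obs_rms=False)
watch_env.set_obs_rms(train_envs.get_obs_rms())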

tianshou/highlevel/env.py

@@ -89,10 +89,17 @@ class VectorEnvType(Enum):
 class Environments(ToStringMixin, ABC):
     """Represents (vectorized) environments for a learning process."""

-    def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
+    def __init__(
+        self,
+        env: gym.Env,
+        train_envs: BaseVectorEnv,
+        test_envs: BaseVectorEnv,
+        watch_env: BaseVectorEnv | None = None,
+    ):
         self.env = env
         self.train_envs = train_envs
         self.test_envs = test_envs
+        self.watch_env = watch_env
         self.persistence: Sequence[Persistence] = []

     @staticmethod
@@ -102,6 +109,7 @@ class Environments(ToStringMixin, ABC):
         venv_type: VectorEnvType,
         num_training_envs: int,
         num_test_envs: int,
+        create_watch_env: bool = False,
     ) -> "Environments":
         """Creates a suitable subtype instance from a factory function that creates a single instance and the type of environment (continuous/discrete).

@@ -110,16 +118,21 @@ class Environments(ToStringMixin, ABC):
         :param venv_type: the vector environment type to use for parallelization
         :param num_training_envs: the number of training environments to create
         :param num_test_envs: the number of test environments to create
+        :param create_watch_env: whether to create an environment for watching the agent
         :return: the instance
         """
         train_envs = venv_type.create_venv([lambda: factory_fn(EnvMode.TRAIN)] * num_training_envs)
         test_envs = venv_type.create_venv([lambda: factory_fn(EnvMode.TEST)] * num_test_envs)
+        if create_watch_env:
+            watch_env = venv_type.create_venv([lambda: factory_fn(EnvMode.WATCH)])
+        else:
+            watch_env = None
         env = factory_fn(EnvMode.TRAIN)
         match env_type:
             case EnvType.CONTINUOUS:
-                return ContinuousEnvironments(env, train_envs, test_envs)
+                return ContinuousEnvironments(env, train_envs, test_envs, watch_env)
             case EnvType.DISCRETE:
-                return DiscreteEnvironments(env, train_envs, test_envs)
+                return DiscreteEnvironments(env, train_envs, test_envs, watch_env)
             case _:
                 raise ValueError(f"Environment type {env_type} not handled")
@@ -164,8 +177,14 @@ class Environments(ToStringMixin, ABC):
 class ContinuousEnvironments(Environments):
     """Represents (vectorized) continuous environments."""

-    def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
-        super().__init__(env, train_envs, test_envs)
+    def __init__(
+        self,
+        env: gym.Env,
+        train_envs: BaseVectorEnv,
+        test_envs: BaseVectorEnv,
+        watch_env: BaseVectorEnv | None = None,
+    ):
+        super().__init__(env, train_envs, test_envs, watch_env)
         self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env)

     @staticmethod
@@ -174,6 +193,7 @@ class ContinuousEnvironments(Environments):
         venv_type: VectorEnvType,
         num_training_envs: int,
         num_test_envs: int,
+        create_watch_env: bool = False,
     ) -> "ContinuousEnvironments":
         """Creates an instance from a factory function that creates a single instance.

@@ -181,6 +201,7 @@ class ContinuousEnvironments(Environments):
         :param venv_type: the vector environment type to use for parallelization
         :param num_training_envs: the number of training environments to create
         :param num_test_envs: the number of test environments to create
+        :param create_watch_env: whether to create an environment for watching the agent
         :return: the instance
         """
         return cast(
@@ -191,6 +212,7 @@ class ContinuousEnvironments(Environments):
                 venv_type,
                 num_training_envs,
                 num_test_envs,
+                create_watch_env,
             ),
         )
@@ -228,8 +250,14 @@ class ContinuousEnvironments(Environments):
 class DiscreteEnvironments(Environments):
     """Represents (vectorized) discrete environments."""

-    def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
-        super().__init__(env, train_envs, test_envs)
+    def __init__(
+        self,
+        env: gym.Env,
+        train_envs: BaseVectorEnv,
+        test_envs: BaseVectorEnv,
+        watch_env: BaseVectorEnv | None = None,
+    ):
+        super().__init__(env, train_envs, test_envs, watch_env)
         self.observation_shape = env.observation_space.shape or env.observation_space.n  # type: ignore
         self.action_shape = env.action_space.shape or env.action_space.n  # type: ignore

@@ -239,6 +267,7 @@ class DiscreteEnvironments(Environments):
         venv_type: VectorEnvType,
         num_training_envs: int,
         num_test_envs: int,
+        create_watch_env: bool = False,
     ) -> "DiscreteEnvironments":
         """Creates an instance from a factory function that creates a single instance.

@@ -246,6 +275,7 @@ class DiscreteEnvironments(Environments):
         :param venv_type: the vector environment type to use for parallelization
         :param num_training_envs: the number of training environments to create
         :param num_test_envs: the number of test environments to create
+        :param create_watch_env: whether to create an environment for watching the agent
         :return: the instance
         """
         return cast(
@@ -256,6 +286,7 @@ class DiscreteEnvironments(Environments):
                 venv_type,
                 num_training_envs,
                 num_test_envs,
+                create_watch_env,
             ),
         )
@@ -329,21 +360,28 @@ class EnvFactory(ToStringMixin, ABC):
         """
         return self.venv_type.create_venv([lambda: self.create_env(mode)] * num_envs)

-    def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
+    def create_envs(
+        self,
+        num_training_envs: int,
+        num_test_envs: int,
+        create_watch_env: bool = False,
+    ) -> Environments:
         """Create environments for learning.

         :param num_training_envs: the number of training environments
         :param num_test_envs: the number of test environments
+        :param create_watch_env: whether to create an environment for watching the agent
         :return: the environments
         """
         env = self.create_env(EnvMode.TRAIN)
         train_envs = self.create_venv(num_training_envs, EnvMode.TRAIN)
         test_envs = self.create_venv(num_test_envs, EnvMode.TEST)
+        watch_env = self.create_venv(1, EnvMode.WATCH) if create_watch_env else None
         match EnvType.from_env(env):
             case EnvType.DISCRETE:
-                return DiscreteEnvironments(env, train_envs, test_envs)
+                return DiscreteEnvironments(env, train_envs, test_envs, watch_env)
             case EnvType.CONTINUOUS:
-                return ContinuousEnvironments(env, train_envs, test_envs)
+                return ContinuousEnvironments(env, train_envs, test_envs, watch_env)
             case _:
                 raise ValueError
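
A sketch of building environments with a watch env from a plain factory function. The render_mode handling and the VectorEnvType.DUMMY member are assumptions for illustration; the from_factory parameters match the diff above:

import gymnasium as gym

from tianshou.highlevel.env import ContinuousEnvironments, EnvMode, VectorEnvType

def factory_fn(mode: EnvMode) -> gym.Env:
    # Render only in WATCH mode; this branching is an illustration,
    # not part of this commit.
    render_mode = "human" if mode == EnvMode.WATCH else None
    return gym.make("Pendulum-v1", render_mode=render_mode)

envs = ContinuousEnvironments.from_factory(
    factory_fn,
    VectorEnvType.DUMMY,  # assumed enum member; pick your venv type
    num_training_envs=4,
    num_test_envs=2,
    create_watch_env=True,
)
assert envs.watch_env is not None  # a single-env venv created in WATCH mode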

tianshou/highlevel/experiment.py

@@ -10,6 +10,7 @@ import numpy as np
 import torch

 from tianshou.data import Collector, InfoStats
+from tianshou.env import BaseVectorEnv
 from tianshou.highlevel.agent import (
     A2CAgentFactory,
     AgentFactory,
@@ -26,7 +27,7 @@ from tianshou.highlevel.agent import (
     TRPOAgentFactory,
 )
 from tianshou.highlevel.config import SamplingConfig
-from tianshou.highlevel.env import EnvFactory, EnvMode
+from tianshou.highlevel.env import EnvFactory
 from tianshou.highlevel.logger import LoggerFactory, LoggerFactoryDefault, TLogger
 from tianshou.highlevel.module.actor import (
     ActorFactory,
@@ -222,6 +223,7 @@ class Experiment(ToStringMixin):
         envs = self.env_factory.create_envs(
             self.sampling_config.num_train_envs,
             self.sampling_config.num_test_envs,
+            create_watch_env=self.config.watch,
         )
         log.info(f"Created {envs}")
@@ -289,11 +291,12 @@
         # watch agent performance
         if self.config.watch:
+            assert envs.watch_env is not None
             log.info("Watching agent performance")
             self._watch_agent(
                 self.config.watch_num_episodes,
                 policy,
-                self.env_factory,
+                envs.watch_env,
                 self.config.watch_render,
             )
@@ -303,11 +306,10 @@
     def _watch_agent(
         num_episodes: int,
         policy: BasePolicy,
-        env_factory: EnvFactory,
+        env: BaseVectorEnv,
         render: float,
     ) -> None:
         policy.eval()
-        env = env_factory.create_env(EnvMode.WATCH)
         collector = Collector(policy, env)
         result = collector.collect(n_episode=num_episodes, render=render)
         assert result.returns_stat is not None  # for mypy
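
The resulting watch flow, reduced to a standalone sketch. The function name is hypothetical and the returns_stat.mean attribute is assumed from Tianshou's collect statistics; the Collector usage matches _watch_agent above:

from tianshou.data import Collector

def watch(policy, watch_env, num_episodes: int = 2, render: float = 0.0) -> None:
    # Mirrors Experiment._watch_agent: the pre-built watch env already
    # carries the training obs_rms, so the policy sees consistently
    # normalized observations instead of raw ones from a fresh env.
    policy.eval()
    collector = Collector(policy, watch_env)
    result = collector.collect(n_episode=num_episodes, render=render)
    assert result.returns_stat is not None  # for mypy
    print(f"watched {num_episodes} episode(s), mean return: {result.returns_stat.mean}")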