diff --git a/examples/mujoco/mujoco_env.py b/examples/mujoco/mujoco_env.py
index 210d465..b04f243 100644
--- a/examples/mujoco/mujoco_env.py
+++ b/examples/mujoco/mujoco_env.py
@@ -1,10 +1,11 @@
 import logging
 import pickle
 
-from tianshou.env import VectorEnvNormObs
+from tianshou.env import BaseVectorEnv, VectorEnvNormObs
 from tianshou.highlevel.env import (
     ContinuousEnvironments,
     EnvFactoryRegistered,
+    EnvMode,
     EnvPoolFactory,
     VectorEnvType,
 )
@@ -56,6 +57,8 @@ class MujocoEnvObsRmsPersistence(Persistence):
             obs_rms = pickle.load(f)
         world.envs.train_envs.set_obs_rms(obs_rms)
         world.envs.test_envs.set_obs_rms(obs_rms)
+        if world.envs.watch_env is not None:
+            world.envs.watch_env.set_obs_rms(obs_rms)
 
 
 class MujocoEnvFactory(EnvFactoryRegistered):
@@ -68,15 +71,31 @@ class MujocoEnvFactory(EnvFactoryRegistered):
         )
         self.obs_norm = obs_norm
 
-    def create_envs(self, num_training_envs: int, num_test_envs: int) -> ContinuousEnvironments:
-        envs = super().create_envs(num_training_envs, num_test_envs)
-        assert isinstance(envs, ContinuousEnvironments)
+    def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv:
+        """Create vectorized environments.
+
+        :param num_envs: the number of environments
+        :param mode: the mode for which to create the environments
+        :return: the vectorized environments
+        """
+        env = super().create_venv(num_envs, mode)
 
         # obs norm wrapper
         if self.obs_norm:
-            envs.train_envs = VectorEnvNormObs(envs.train_envs)
-            envs.test_envs = VectorEnvNormObs(envs.test_envs, update_obs_rms=False)
-            envs.test_envs.set_obs_rms(envs.train_envs.get_obs_rms())
-            envs.set_persistence(MujocoEnvObsRmsPersistence())
+            env = VectorEnvNormObs(env, update_obs_rms=mode == EnvMode.TRAIN)
+        return env
+
+    def create_envs(
+        self,
+        num_training_envs: int,
+        num_test_envs: int,
+        create_watch_env: bool = False,
+    ) -> ContinuousEnvironments:
+        envs = super().create_envs(num_training_envs, num_test_envs, create_watch_env)
+        assert isinstance(envs, ContinuousEnvironments)
+
+        if self.obs_norm:
+            envs.test_envs.set_obs_rms(envs.train_envs.get_obs_rms())
+            if envs.watch_env is not None:
+                envs.watch_env.set_obs_rms(envs.train_envs.get_obs_rms())
+            envs.set_persistence(MujocoEnvObsRmsPersistence())
         return envs
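The change above moves observation normalization out of `create_envs` and into `create_venv`, so each vectorized environment is wrapped individually and only the training environments update the running statistics. A minimal sketch of the resulting semantics, outside the factory machinery (the task name and environment counts are arbitrary choices for illustration):

```python
import gymnasium as gym

from tianshou.env import DummyVectorEnv, VectorEnvNormObs

# Training envs update the running mean/std of observations; evaluation
# envs only apply them. This mirrors `update_obs_rms=mode == EnvMode.TRAIN`
# in create_venv above.
train_envs = VectorEnvNormObs(
    DummyVectorEnv([lambda: gym.make("Pendulum-v1") for _ in range(4)]),
)
test_envs = VectorEnvNormObs(
    DummyVectorEnv([lambda: gym.make("Pendulum-v1") for _ in range(2)]),
    update_obs_rms=False,
)

# Evaluation environments normalize with the statistics gathered during
# training, just as create_envs does for test_envs and watch_env above.
test_envs.set_obs_rms(train_envs.get_obs_rms())
```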
diff --git a/tianshou/highlevel/env.py b/tianshou/highlevel/env.py
index c343891..71de0f8 100644
--- a/tianshou/highlevel/env.py
+++ b/tianshou/highlevel/env.py
@@ -89,10 +89,17 @@ class VectorEnvType(Enum):
 class Environments(ToStringMixin, ABC):
     """Represents (vectorized) environments for a learning process."""
 
-    def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
+    def __init__(
+        self,
+        env: gym.Env,
+        train_envs: BaseVectorEnv,
+        test_envs: BaseVectorEnv,
+        watch_env: BaseVectorEnv | None = None,
+    ):
         self.env = env
         self.train_envs = train_envs
         self.test_envs = test_envs
+        self.watch_env = watch_env
         self.persistence: Sequence[Persistence] = []
 
     @staticmethod
@@ -102,6 +109,7 @@ class Environments(ToStringMixin, ABC):
         venv_type: VectorEnvType,
         num_training_envs: int,
         num_test_envs: int,
+        create_watch_env: bool = False,
     ) -> "Environments":
         """Creates a suitable subtype instance from a factory function that creates a single instance
         and the type of environment (continuous/discrete).
@@ -110,16 +118,21 @@ class Environments(ToStringMixin, ABC):
         :param venv_type: the vector environment type to use for parallelization
         :param num_training_envs: the number of training environments to create
         :param num_test_envs: the number of test environments to create
+        :param create_watch_env: whether to create an environment for watching the agent
         :return: the instance
         """
         train_envs = venv_type.create_venv([lambda: factory_fn(EnvMode.TRAIN)] * num_training_envs)
         test_envs = venv_type.create_venv([lambda: factory_fn(EnvMode.TEST)] * num_test_envs)
+        if create_watch_env:
+            watch_env = venv_type.create_venv([lambda: factory_fn(EnvMode.WATCH)])
+        else:
+            watch_env = None
         env = factory_fn(EnvMode.TRAIN)
         match env_type:
             case EnvType.CONTINUOUS:
-                return ContinuousEnvironments(env, train_envs, test_envs)
+                return ContinuousEnvironments(env, train_envs, test_envs, watch_env)
             case EnvType.DISCRETE:
-                return DiscreteEnvironments(env, train_envs, test_envs)
+                return DiscreteEnvironments(env, train_envs, test_envs, watch_env)
             case _:
                 raise ValueError(f"Environment type {env_type} not handled")
 
@@ -164,8 +177,14 @@ class Environments(ToStringMixin, ABC):
 class ContinuousEnvironments(Environments):
     """Represents (vectorized) continuous environments."""
 
-    def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
-        super().__init__(env, train_envs, test_envs)
+    def __init__(
+        self,
+        env: gym.Env,
+        train_envs: BaseVectorEnv,
+        test_envs: BaseVectorEnv,
+        watch_env: BaseVectorEnv | None = None,
+    ):
+        super().__init__(env, train_envs, test_envs, watch_env)
         self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env)
 
     @staticmethod
@@ -174,6 +193,7 @@ class ContinuousEnvironments(Environments):
         venv_type: VectorEnvType,
         num_training_envs: int,
         num_test_envs: int,
+        create_watch_env: bool = False,
     ) -> "ContinuousEnvironments":
         """Creates an instance from a factory function that creates a single instance.
 
@@ -181,6 +201,7 @@ class ContinuousEnvironments(Environments):
         :param venv_type: the vector environment type to use for parallelization
         :param num_training_envs: the number of training environments to create
         :param num_test_envs: the number of test environments to create
+        :param create_watch_env: whether to create an environment for watching the agent
         :return: the instance
         """
         return cast(
@@ -191,6 +212,7 @@ class ContinuousEnvironments(Environments):
                 venv_type,
                 num_training_envs,
                 num_test_envs,
+                create_watch_env,
             ),
         )
 
@@ -228,8 +250,14 @@ class ContinuousEnvironments(Environments):
 class DiscreteEnvironments(Environments):
     """Represents (vectorized) discrete environments."""
 
-    def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
-        super().__init__(env, train_envs, test_envs)
+    def __init__(
+        self,
+        env: gym.Env,
+        train_envs: BaseVectorEnv,
+        test_envs: BaseVectorEnv,
+        watch_env: BaseVectorEnv | None = None,
+    ):
+        super().__init__(env, train_envs, test_envs, watch_env)
         self.observation_shape = env.observation_space.shape or env.observation_space.n  # type: ignore
         self.action_shape = env.action_space.shape or env.action_space.n  # type: ignore
 
@@ -239,6 +267,7 @@ class DiscreteEnvironments(Environments):
         venv_type: VectorEnvType,
         num_training_envs: int,
         num_test_envs: int,
+        create_watch_env: bool = False,
     ) -> "DiscreteEnvironments":
         """Creates an instance from a factory function that creates a single instance.
 
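The high-level API now threads `create_watch_env` through every construction path; the watch environment is always a single-instance vectorized environment created in `EnvMode.WATCH`. A hedged sketch of how a factory function might react to the new mode (the `from_factory` method name is inferred from the enclosing class, as the name itself is not visible in the hunks above, and the `render_mode` handling is an assumption about what a WATCH-mode environment would typically change):

```python
import gymnasium as gym

from tianshou.highlevel.env import (
    ContinuousEnvironments,
    EnvMode,
    VectorEnvType,
)

def make_pendulum(mode: EnvMode) -> gym.Env:
    # A watch environment typically differs only in its render mode; what
    # else to vary per mode is up to the concrete factory function.
    render_mode = "human" if mode == EnvMode.WATCH else None
    return gym.make("Pendulum-v1", render_mode=render_mode)

envs = ContinuousEnvironments.from_factory(
    make_pendulum,
    VectorEnvType.DUMMY,
    num_training_envs=4,
    num_test_envs=2,
    create_watch_env=True,
)
assert envs.watch_env is not None  # a single-instance venv in WATCH mode
```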
@@ -246,6 +275,7 @@ class DiscreteEnvironments(Environments):
         :param venv_type: the vector environment type to use for parallelization
         :param num_training_envs: the number of training environments to create
         :param num_test_envs: the number of test environments to create
+        :param create_watch_env: whether to create an environment for watching the agent
         :return: the instance
         """
         return cast(
@@ -256,6 +286,7 @@ class DiscreteEnvironments(Environments):
                 venv_type,
                 num_training_envs,
                 num_test_envs,
+                create_watch_env,
             ),
         )
 
@@ -329,21 +360,28 @@ class EnvFactory(ToStringMixin, ABC):
         """
         return self.venv_type.create_venv([lambda: self.create_env(mode)] * num_envs)
 
-    def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
+    def create_envs(
+        self,
+        num_training_envs: int,
+        num_test_envs: int,
+        create_watch_env: bool = False,
+    ) -> Environments:
         """Create environments for learning.
 
         :param num_training_envs: the number of training environments
         :param num_test_envs: the number of test environments
+        :param create_watch_env: whether to create an environment for watching the agent
         :return: the environments
         """
         env = self.create_env(EnvMode.TRAIN)
         train_envs = self.create_venv(num_training_envs, EnvMode.TRAIN)
         test_envs = self.create_venv(num_test_envs, EnvMode.TEST)
+        watch_env = self.create_venv(1, EnvMode.WATCH) if create_watch_env else None
         match EnvType.from_env(env):
             case EnvType.DISCRETE:
-                return DiscreteEnvironments(env, train_envs, test_envs)
+                return DiscreteEnvironments(env, train_envs, test_envs, watch_env)
             case EnvType.CONTINUOUS:
-                return ContinuousEnvironments(env, train_envs, test_envs)
+                return ContinuousEnvironments(env, train_envs, test_envs, watch_env)
             case _:
                 raise ValueError
diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py
index 6fa2853..5b7d388 100644
--- a/tianshou/highlevel/experiment.py
+++ b/tianshou/highlevel/experiment.py
@@ -10,6 +10,7 @@ import numpy as np
 import torch
 
 from tianshou.data import Collector, InfoStats
+from tianshou.env import BaseVectorEnv
 from tianshou.highlevel.agent import (
     A2CAgentFactory,
     AgentFactory,
@@ -26,7 +27,7 @@ from tianshou.highlevel.agent import (
     TRPOAgentFactory,
 )
 from tianshou.highlevel.config import SamplingConfig
-from tianshou.highlevel.env import EnvFactory, EnvMode
+from tianshou.highlevel.env import EnvFactory
 from tianshou.highlevel.logger import LoggerFactory, LoggerFactoryDefault, TLogger
 from tianshou.highlevel.module.actor import (
     ActorFactory,
@@ -222,6 +223,7 @@ class Experiment(ToStringMixin):
         envs = self.env_factory.create_envs(
             self.sampling_config.num_train_envs,
             self.sampling_config.num_test_envs,
+            create_watch_env=self.config.watch,
         )
         log.info(f"Created {envs}")
 
@@ -289,11 +291,12 @@ class Experiment(ToStringMixin):
 
         # watch agent performance
         if self.config.watch:
+            assert envs.watch_env is not None
             log.info("Watching agent performance")
             self._watch_agent(
                 self.config.watch_num_episodes,
                 policy,
-                self.env_factory,
+                envs.watch_env,
                 self.config.watch_render,
             )
 
@@ -303,11 +306,10 @@ class Experiment(ToStringMixin):
     @staticmethod
     def _watch_agent(
         num_episodes: int,
         policy: BasePolicy,
-        env_factory: EnvFactory,
+        env: BaseVectorEnv,
         render: float,
     ) -> None:
         policy.eval()
-        env = env_factory.create_env(EnvMode.WATCH)
         collector = Collector(policy, env)
         result = collector.collect(n_episode=num_episodes, render=render)
         assert result.returns_stat is not None  # for mypy
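Taken together, a user only sets `watch=True` on the experiment config: the watch environment is then created alongside the train/test environments, shares their observation normalization, and is reused by `_watch_agent` instead of being constructed ad hoc from the factory. A sketch of direct factory usage (the `MujocoEnvFactory` constructor arguments are assumptions, since its `__init__` is not part of this diff):

```python
from examples.mujoco.mujoco_env import MujocoEnvFactory

# Constructor arguments are assumed for illustration; __init__ is not shown
# in this diff.
env_factory = MujocoEnvFactory("HalfCheetah-v4", seed=42, obs_norm=True)

# The new keyword argument requests a dedicated watch environment in
# addition to the train/test environments.
envs = env_factory.create_envs(
    num_training_envs=4,
    num_test_envs=2,
    create_watch_env=True,
)

# With obs_norm=True, train, test, and watch environments share the running
# observation statistics collected during training.
assert envs.watch_env is not None
```

Note that `MujocoEnvObsRmsPersistence.restore` now also restores the statistics into the watch environment, so a reloaded experiment watches with the same normalization it trained with.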