Fix/add watch env with obs rms (#1061)
Supports deciding whether to watch the agent performing on the env using high-level interfaces
This commit is contained in:
		
							parent
							
								
									49781e715e
								
							
						
					
					
						commit
						7c970df53f
					
				| @ -1,10 +1,11 @@ | |||||||
| import logging | import logging | ||||||
| import pickle | import pickle | ||||||
| 
 | 
 | ||||||
| from tianshou.env import VectorEnvNormObs | from tianshou.env import BaseVectorEnv, VectorEnvNormObs | ||||||
| from tianshou.highlevel.env import ( | from tianshou.highlevel.env import ( | ||||||
|     ContinuousEnvironments, |     ContinuousEnvironments, | ||||||
|     EnvFactoryRegistered, |     EnvFactoryRegistered, | ||||||
|  |     EnvMode, | ||||||
|     EnvPoolFactory, |     EnvPoolFactory, | ||||||
|     VectorEnvType, |     VectorEnvType, | ||||||
| ) | ) | ||||||
| @ -56,6 +57,8 @@ class MujocoEnvObsRmsPersistence(Persistence): | |||||||
|             obs_rms = pickle.load(f) |             obs_rms = pickle.load(f) | ||||||
|         world.envs.train_envs.set_obs_rms(obs_rms) |         world.envs.train_envs.set_obs_rms(obs_rms) | ||||||
|         world.envs.test_envs.set_obs_rms(obs_rms) |         world.envs.test_envs.set_obs_rms(obs_rms) | ||||||
|  |         if world.envs.watch_env is not None: | ||||||
|  |             world.envs.watch_env.set_obs_rms(obs_rms) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class MujocoEnvFactory(EnvFactoryRegistered): | class MujocoEnvFactory(EnvFactoryRegistered): | ||||||
| @ -68,15 +71,31 @@ class MujocoEnvFactory(EnvFactoryRegistered): | |||||||
|         ) |         ) | ||||||
|         self.obs_norm = obs_norm |         self.obs_norm = obs_norm | ||||||
| 
 | 
 | ||||||
|     def create_envs(self, num_training_envs: int, num_test_envs: int) -> ContinuousEnvironments: |     def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv: | ||||||
|         envs = super().create_envs(num_training_envs, num_test_envs) |         """Create vectorized environments. | ||||||
|         assert isinstance(envs, ContinuousEnvironments) |  | ||||||
| 
 | 
 | ||||||
|  |         :param num_envs: the number of environments | ||||||
|  |         :param mode: the mode for which to create | ||||||
|  |         :return: the vectorized environments | ||||||
|  |         """ | ||||||
|  |         env = super().create_venv(num_envs, mode) | ||||||
|         # obs norm wrapper |         # obs norm wrapper | ||||||
|         if self.obs_norm: |         if self.obs_norm: | ||||||
|             envs.train_envs = VectorEnvNormObs(envs.train_envs) |             env = VectorEnvNormObs(env, update_obs_rms=mode == EnvMode.TRAIN) | ||||||
|             envs.test_envs = VectorEnvNormObs(envs.test_envs, update_obs_rms=False) |         return env | ||||||
|             envs.test_envs.set_obs_rms(envs.train_envs.get_obs_rms()) |  | ||||||
|             envs.set_persistence(MujocoEnvObsRmsPersistence()) |  | ||||||
| 
 | 
 | ||||||
|  |     def create_envs( | ||||||
|  |         self, | ||||||
|  |         num_training_envs: int, | ||||||
|  |         num_test_envs: int, | ||||||
|  |         create_watch_env: bool = False, | ||||||
|  |     ) -> ContinuousEnvironments: | ||||||
|  |         envs = super().create_envs(num_training_envs, num_test_envs, create_watch_env) | ||||||
|  |         assert isinstance(envs, ContinuousEnvironments) | ||||||
|  | 
 | ||||||
|  |         if self.obs_norm: | ||||||
|  |             envs.test_envs.set_obs_rms(envs.train_envs.get_obs_rms()) | ||||||
|  |             if envs.watch_env is not None: | ||||||
|  |                 envs.watch_env.set_obs_rms(envs.train_envs.get_obs_rms()) | ||||||
|  |             envs.set_persistence(MujocoEnvObsRmsPersistence()) | ||||||
|         return envs |         return envs | ||||||
|  | |||||||
| @ -89,10 +89,17 @@ class VectorEnvType(Enum): | |||||||
| class Environments(ToStringMixin, ABC): | class Environments(ToStringMixin, ABC): | ||||||
|     """Represents (vectorized) environments for a learning process.""" |     """Represents (vectorized) environments for a learning process.""" | ||||||
| 
 | 
 | ||||||
|     def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv): |     def __init__( | ||||||
|  |         self, | ||||||
|  |         env: gym.Env, | ||||||
|  |         train_envs: BaseVectorEnv, | ||||||
|  |         test_envs: BaseVectorEnv, | ||||||
|  |         watch_env: BaseVectorEnv | None = None, | ||||||
|  |     ): | ||||||
|         self.env = env |         self.env = env | ||||||
|         self.train_envs = train_envs |         self.train_envs = train_envs | ||||||
|         self.test_envs = test_envs |         self.test_envs = test_envs | ||||||
|  |         self.watch_env = watch_env | ||||||
|         self.persistence: Sequence[Persistence] = [] |         self.persistence: Sequence[Persistence] = [] | ||||||
| 
 | 
 | ||||||
|     @staticmethod |     @staticmethod | ||||||
| @ -102,6 +109,7 @@ class Environments(ToStringMixin, ABC): | |||||||
|         venv_type: VectorEnvType, |         venv_type: VectorEnvType, | ||||||
|         num_training_envs: int, |         num_training_envs: int, | ||||||
|         num_test_envs: int, |         num_test_envs: int, | ||||||
|  |         create_watch_env: bool = False, | ||||||
|     ) -> "Environments": |     ) -> "Environments": | ||||||
|         """Creates a suitable subtype instance from a factory function that creates a single instance and the type of environment (continuous/discrete). |         """Creates a suitable subtype instance from a factory function that creates a single instance and the type of environment (continuous/discrete). | ||||||
| 
 | 
 | ||||||
| @ -110,16 +118,21 @@ class Environments(ToStringMixin, ABC): | |||||||
|         :param venv_type: the vector environment type to use for parallelization |         :param venv_type: the vector environment type to use for parallelization | ||||||
|         :param num_training_envs: the number of training environments to create |         :param num_training_envs: the number of training environments to create | ||||||
|         :param num_test_envs: the number of test environments to create |         :param num_test_envs: the number of test environments to create | ||||||
|  |         :param create_watch_env: whether to create an environment for watching the agent | ||||||
|         :return: the instance |         :return: the instance | ||||||
|         """ |         """ | ||||||
|         train_envs = venv_type.create_venv([lambda: factory_fn(EnvMode.TRAIN)] * num_training_envs) |         train_envs = venv_type.create_venv([lambda: factory_fn(EnvMode.TRAIN)] * num_training_envs) | ||||||
|         test_envs = venv_type.create_venv([lambda: factory_fn(EnvMode.TEST)] * num_test_envs) |         test_envs = venv_type.create_venv([lambda: factory_fn(EnvMode.TEST)] * num_test_envs) | ||||||
|  |         if create_watch_env: | ||||||
|  |             watch_env = venv_type.create_venv([lambda: factory_fn(EnvMode.WATCH)]) | ||||||
|  |         else: | ||||||
|  |             watch_env = None | ||||||
|         env = factory_fn(EnvMode.TRAIN) |         env = factory_fn(EnvMode.TRAIN) | ||||||
|         match env_type: |         match env_type: | ||||||
|             case EnvType.CONTINUOUS: |             case EnvType.CONTINUOUS: | ||||||
|                 return ContinuousEnvironments(env, train_envs, test_envs) |                 return ContinuousEnvironments(env, train_envs, test_envs, watch_env) | ||||||
|             case EnvType.DISCRETE: |             case EnvType.DISCRETE: | ||||||
|                 return DiscreteEnvironments(env, train_envs, test_envs) |                 return DiscreteEnvironments(env, train_envs, test_envs, watch_env) | ||||||
|             case _: |             case _: | ||||||
|                 raise ValueError(f"Environment type {env_type} not handled") |                 raise ValueError(f"Environment type {env_type} not handled") | ||||||
| 
 | 
 | ||||||
| @ -164,8 +177,14 @@ class Environments(ToStringMixin, ABC): | |||||||
| class ContinuousEnvironments(Environments): | class ContinuousEnvironments(Environments): | ||||||
|     """Represents (vectorized) continuous environments.""" |     """Represents (vectorized) continuous environments.""" | ||||||
| 
 | 
 | ||||||
|     def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv): |     def __init__( | ||||||
|         super().__init__(env, train_envs, test_envs) |         self, | ||||||
|  |         env: gym.Env, | ||||||
|  |         train_envs: BaseVectorEnv, | ||||||
|  |         test_envs: BaseVectorEnv, | ||||||
|  |         watch_env: BaseVectorEnv | None = None, | ||||||
|  |     ): | ||||||
|  |         super().__init__(env, train_envs, test_envs, watch_env) | ||||||
|         self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env) |         self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env) | ||||||
| 
 | 
 | ||||||
|     @staticmethod |     @staticmethod | ||||||
| @ -174,6 +193,7 @@ class ContinuousEnvironments(Environments): | |||||||
|         venv_type: VectorEnvType, |         venv_type: VectorEnvType, | ||||||
|         num_training_envs: int, |         num_training_envs: int, | ||||||
|         num_test_envs: int, |         num_test_envs: int, | ||||||
|  |         create_watch_env: bool = False, | ||||||
|     ) -> "ContinuousEnvironments": |     ) -> "ContinuousEnvironments": | ||||||
|         """Creates an instance from a factory function that creates a single instance. |         """Creates an instance from a factory function that creates a single instance. | ||||||
| 
 | 
 | ||||||
| @ -181,6 +201,7 @@ class ContinuousEnvironments(Environments): | |||||||
|         :param venv_type: the vector environment type to use for parallelization |         :param venv_type: the vector environment type to use for parallelization | ||||||
|         :param num_training_envs: the number of training environments to create |         :param num_training_envs: the number of training environments to create | ||||||
|         :param num_test_envs: the number of test environments to create |         :param num_test_envs: the number of test environments to create | ||||||
|  |         :param create_watch_env: whether to create an environment for watching the agent | ||||||
|         :return: the instance |         :return: the instance | ||||||
|         """ |         """ | ||||||
|         return cast( |         return cast( | ||||||
| @ -191,6 +212,7 @@ class ContinuousEnvironments(Environments): | |||||||
|                 venv_type, |                 venv_type, | ||||||
|                 num_training_envs, |                 num_training_envs, | ||||||
|                 num_test_envs, |                 num_test_envs, | ||||||
|  |                 create_watch_env, | ||||||
|             ), |             ), | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
| @ -228,8 +250,14 @@ class ContinuousEnvironments(Environments): | |||||||
| class DiscreteEnvironments(Environments): | class DiscreteEnvironments(Environments): | ||||||
|     """Represents (vectorized) discrete environments.""" |     """Represents (vectorized) discrete environments.""" | ||||||
| 
 | 
 | ||||||
|     def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv): |     def __init__( | ||||||
|         super().__init__(env, train_envs, test_envs) |         self, | ||||||
|  |         env: gym.Env, | ||||||
|  |         train_envs: BaseVectorEnv, | ||||||
|  |         test_envs: BaseVectorEnv, | ||||||
|  |         watch_env: BaseVectorEnv | None = None, | ||||||
|  |     ): | ||||||
|  |         super().__init__(env, train_envs, test_envs, watch_env) | ||||||
|         self.observation_shape = env.observation_space.shape or env.observation_space.n  # type: ignore |         self.observation_shape = env.observation_space.shape or env.observation_space.n  # type: ignore | ||||||
|         self.action_shape = env.action_space.shape or env.action_space.n  # type: ignore |         self.action_shape = env.action_space.shape or env.action_space.n  # type: ignore | ||||||
| 
 | 
 | ||||||
| @ -239,6 +267,7 @@ class DiscreteEnvironments(Environments): | |||||||
|         venv_type: VectorEnvType, |         venv_type: VectorEnvType, | ||||||
|         num_training_envs: int, |         num_training_envs: int, | ||||||
|         num_test_envs: int, |         num_test_envs: int, | ||||||
|  |         create_watch_env: bool = False, | ||||||
|     ) -> "DiscreteEnvironments": |     ) -> "DiscreteEnvironments": | ||||||
|         """Creates an instance from a factory function that creates a single instance. |         """Creates an instance from a factory function that creates a single instance. | ||||||
| 
 | 
 | ||||||
| @ -246,6 +275,7 @@ class DiscreteEnvironments(Environments): | |||||||
|         :param venv_type: the vector environment type to use for parallelization |         :param venv_type: the vector environment type to use for parallelization | ||||||
|         :param num_training_envs: the number of training environments to create |         :param num_training_envs: the number of training environments to create | ||||||
|         :param num_test_envs: the number of test environments to create |         :param num_test_envs: the number of test environments to create | ||||||
|  |         :param create_watch_env: whether to create an environment for watching the agent | ||||||
|         :return: the instance |         :return: the instance | ||||||
|         """ |         """ | ||||||
|         return cast( |         return cast( | ||||||
| @ -256,6 +286,7 @@ class DiscreteEnvironments(Environments): | |||||||
|                 venv_type, |                 venv_type, | ||||||
|                 num_training_envs, |                 num_training_envs, | ||||||
|                 num_test_envs, |                 num_test_envs, | ||||||
|  |                 create_watch_env, | ||||||
|             ), |             ), | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
| @ -329,21 +360,28 @@ class EnvFactory(ToStringMixin, ABC): | |||||||
|         """ |         """ | ||||||
|         return self.venv_type.create_venv([lambda: self.create_env(mode)] * num_envs) |         return self.venv_type.create_venv([lambda: self.create_env(mode)] * num_envs) | ||||||
| 
 | 
 | ||||||
|     def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments: |     def create_envs( | ||||||
|  |         self, | ||||||
|  |         num_training_envs: int, | ||||||
|  |         num_test_envs: int, | ||||||
|  |         create_watch_env: bool = False, | ||||||
|  |     ) -> Environments: | ||||||
|         """Create environments for learning. |         """Create environments for learning. | ||||||
| 
 | 
 | ||||||
|         :param num_training_envs: the number of training environments |         :param num_training_envs: the number of training environments | ||||||
|         :param num_test_envs: the number of test environments |         :param num_test_envs: the number of test environments | ||||||
|  |         :param create_watch_env: whether to create an environment for watching the agent | ||||||
|         :return: the environments |         :return: the environments | ||||||
|         """ |         """ | ||||||
|         env = self.create_env(EnvMode.TRAIN) |         env = self.create_env(EnvMode.TRAIN) | ||||||
|         train_envs = self.create_venv(num_training_envs, EnvMode.TRAIN) |         train_envs = self.create_venv(num_training_envs, EnvMode.TRAIN) | ||||||
|         test_envs = self.create_venv(num_test_envs, EnvMode.TEST) |         test_envs = self.create_venv(num_test_envs, EnvMode.TEST) | ||||||
|  |         watch_env = self.create_venv(1, EnvMode.WATCH) if create_watch_env else None | ||||||
|         match EnvType.from_env(env): |         match EnvType.from_env(env): | ||||||
|             case EnvType.DISCRETE: |             case EnvType.DISCRETE: | ||||||
|                 return DiscreteEnvironments(env, train_envs, test_envs) |                 return DiscreteEnvironments(env, train_envs, test_envs, watch_env) | ||||||
|             case EnvType.CONTINUOUS: |             case EnvType.CONTINUOUS: | ||||||
|                 return ContinuousEnvironments(env, train_envs, test_envs) |                 return ContinuousEnvironments(env, train_envs, test_envs, watch_env) | ||||||
|             case _: |             case _: | ||||||
|                 raise ValueError |                 raise ValueError | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -10,6 +10,7 @@ import numpy as np | |||||||
| import torch | import torch | ||||||
| 
 | 
 | ||||||
| from tianshou.data import Collector, InfoStats | from tianshou.data import Collector, InfoStats | ||||||
|  | from tianshou.env import BaseVectorEnv | ||||||
| from tianshou.highlevel.agent import ( | from tianshou.highlevel.agent import ( | ||||||
|     A2CAgentFactory, |     A2CAgentFactory, | ||||||
|     AgentFactory, |     AgentFactory, | ||||||
| @ -26,7 +27,7 @@ from tianshou.highlevel.agent import ( | |||||||
|     TRPOAgentFactory, |     TRPOAgentFactory, | ||||||
| ) | ) | ||||||
| from tianshou.highlevel.config import SamplingConfig | from tianshou.highlevel.config import SamplingConfig | ||||||
| from tianshou.highlevel.env import EnvFactory, EnvMode | from tianshou.highlevel.env import EnvFactory | ||||||
| from tianshou.highlevel.logger import LoggerFactory, LoggerFactoryDefault, TLogger | from tianshou.highlevel.logger import LoggerFactory, LoggerFactoryDefault, TLogger | ||||||
| from tianshou.highlevel.module.actor import ( | from tianshou.highlevel.module.actor import ( | ||||||
|     ActorFactory, |     ActorFactory, | ||||||
| @ -222,6 +223,7 @@ class Experiment(ToStringMixin): | |||||||
|             envs = self.env_factory.create_envs( |             envs = self.env_factory.create_envs( | ||||||
|                 self.sampling_config.num_train_envs, |                 self.sampling_config.num_train_envs, | ||||||
|                 self.sampling_config.num_test_envs, |                 self.sampling_config.num_test_envs, | ||||||
|  |                 create_watch_env=self.config.watch, | ||||||
|             ) |             ) | ||||||
|             log.info(f"Created {envs}") |             log.info(f"Created {envs}") | ||||||
| 
 | 
 | ||||||
| @ -289,11 +291,12 @@ class Experiment(ToStringMixin): | |||||||
| 
 | 
 | ||||||
|             # watch agent performance |             # watch agent performance | ||||||
|             if self.config.watch: |             if self.config.watch: | ||||||
|  |                 assert envs.watch_env is not None | ||||||
|                 log.info("Watching agent performance") |                 log.info("Watching agent performance") | ||||||
|                 self._watch_agent( |                 self._watch_agent( | ||||||
|                     self.config.watch_num_episodes, |                     self.config.watch_num_episodes, | ||||||
|                     policy, |                     policy, | ||||||
|                     self.env_factory, |                     envs.watch_env, | ||||||
|                     self.config.watch_render, |                     self.config.watch_render, | ||||||
|                 ) |                 ) | ||||||
| 
 | 
 | ||||||
| @ -303,11 +306,10 @@ class Experiment(ToStringMixin): | |||||||
|     def _watch_agent( |     def _watch_agent( | ||||||
|         num_episodes: int, |         num_episodes: int, | ||||||
|         policy: BasePolicy, |         policy: BasePolicy, | ||||||
|         env_factory: EnvFactory, |         env: BaseVectorEnv, | ||||||
|         render: float, |         render: float, | ||||||
|     ) -> None: |     ) -> None: | ||||||
|         policy.eval() |         policy.eval() | ||||||
|         env = env_factory.create_env(EnvMode.WATCH) |  | ||||||
|         collector = Collector(policy, env) |         collector = Collector(policy, env) | ||||||
|         result = collector.collect(n_episode=num_episodes, render=render) |         result = collector.collect(n_episode=num_episodes, render=render) | ||||||
|         assert result.returns_stat is not None  # for mypy |         assert result.returns_stat is not None  # for mypy | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user