| 
									
										
										
										
											2023-10-12 17:40:16 +02:00
										 |  |  | import logging | 
					
						
							| 
									
										
										
										
											2023-10-12 15:01:49 +02:00
										 |  |  | import pickle | 
					
						
							| 
									
										
										
										
											2022-05-05 07:55:15 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-10 15:37:58 +01:00
										 |  |  | from tianshou.env import VectorEnvNormObs | 
					
						
							|  |  |  | from tianshou.highlevel.env import ( | 
					
						
							|  |  |  |     ContinuousEnvironments, | 
					
						
							| 
									
										
										
										
											2024-01-16 12:22:07 +01:00
										 |  |  |     EnvFactoryRegistered, | 
					
						
							| 
									
										
										
										
											2024-01-10 15:37:58 +01:00
										 |  |  |     EnvPoolFactory, | 
					
						
							|  |  |  |     VectorEnvType, | 
					
						
							|  |  |  | ) | 
					
						
							| 
									
										
										
										
											2023-10-12 17:40:16 +02:00
										 |  |  | from tianshou.highlevel.persistence import Persistence, PersistEvent, RestoreEvent | 
					
						
							| 
									
										
										
										
											2023-10-12 15:01:49 +02:00
										 |  |  | from tianshou.highlevel.world import World | 
					
						
							| 
									
										
										
										
											2022-05-17 17:41:59 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-16 12:16:46 +01:00
										 |  |  | envpool_is_available = True | 
					
						
							| 
									
										
										
										
											2022-05-05 07:55:15 -04:00
										 |  |  | try: | 
					
						
							|  |  |  |     import envpool | 
					
						
							|  |  |  | except ImportError: | 
					
						
							| 
									
										
										
										
											2024-01-16 12:16:46 +01:00
										 |  |  |     envpool_is_available = False | 
					
						
							| 
									
										
										
										
											2022-05-05 07:55:15 -04:00
										 |  |  |     envpool = None | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-10-12 15:01:49 +02:00
										 |  |  | log = logging.getLogger(__name__) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-05-05 07:55:15 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-20 09:29:34 +02:00
										 |  |  | def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: int, obs_norm: bool): | 
					
						
							| 
									
										
										
										
											2022-05-05 07:55:15 -04:00
										 |  |  |     """Wrapper function for Mujoco env.
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     If EnvPool is installed, it will automatically switch to EnvPool's Mujoco env. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     :return: a tuple of (single env, training envs, test envs). | 
					
						
							|  |  |  |     """
 | 
					
						
							| 
									
										
										
										
											2024-01-10 15:37:58 +01:00
										 |  |  |     envs = MujocoEnvFactory(task, seed, obs_norm=obs_norm).create_envs( | 
					
						
							|  |  |  |         num_train_envs, | 
					
						
							|  |  |  |         num_test_envs, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     return envs.env, envs.train_envs, envs.test_envs | 
					
						
							| 
									
										
										
										
											2023-09-19 18:53:11 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-10-12 15:01:49 +02:00
										 |  |  | class MujocoEnvObsRmsPersistence(Persistence): | 
					
						
							|  |  |  |     FILENAME = "env_obs_rms.pkl" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def persist(self, event: PersistEvent, world: World) -> None: | 
					
						
							|  |  |  |         if event != PersistEvent.PERSIST_POLICY: | 
					
						
							|  |  |  |             return | 
					
						
							|  |  |  |         obs_rms = world.envs.train_envs.get_obs_rms() | 
					
						
							|  |  |  |         path = world.persist_path(self.FILENAME) | 
					
						
							|  |  |  |         log.info(f"Saving environment obs_rms value to {path}") | 
					
						
							|  |  |  |         with open(path, "wb") as f: | 
					
						
							|  |  |  |             pickle.dump(obs_rms, f) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def restore(self, event: RestoreEvent, world: World): | 
					
						
							|  |  |  |         if event != RestoreEvent.RESTORE_POLICY: | 
					
						
							|  |  |  |             return | 
					
						
							|  |  |  |         path = world.restore_path(self.FILENAME) | 
					
						
							|  |  |  |         log.info(f"Restoring environment obs_rms value from {path}") | 
					
						
							|  |  |  |         with open(path, "rb") as f: | 
					
						
							|  |  |  |             obs_rms = pickle.load(f) | 
					
						
							|  |  |  |         world.envs.train_envs.set_obs_rms(obs_rms) | 
					
						
							|  |  |  |         world.envs.test_envs.set_obs_rms(obs_rms) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-16 12:22:07 +01:00
										 |  |  | class MujocoEnvFactory(EnvFactoryRegistered): | 
					
						
							| 
									
										
										
										
											2024-02-06 14:24:30 +01:00
										 |  |  |     def __init__(self, task: str, seed: int, obs_norm=True) -> None: | 
					
						
							| 
									
										
										
										
											2024-01-10 15:37:58 +01:00
										 |  |  |         super().__init__( | 
					
						
							|  |  |  |             task=task, | 
					
						
							|  |  |  |             seed=seed, | 
					
						
							|  |  |  |             venv_type=VectorEnvType.SUBPROC_SHARED_MEM, | 
					
						
							| 
									
										
										
										
											2024-01-16 12:16:46 +01:00
										 |  |  |             envpool_factory=EnvPoolFactory() if envpool_is_available else None, | 
					
						
							| 
									
										
										
										
											2024-01-10 15:37:58 +01:00
										 |  |  |         ) | 
					
						
							| 
									
										
										
										
											2023-10-18 13:20:26 +02:00
										 |  |  |         self.obs_norm = obs_norm | 
					
						
							| 
									
										
										
										
											2023-09-19 18:53:11 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-10-24 12:07:23 +02:00
										 |  |  |     def create_envs(self, num_training_envs: int, num_test_envs: int) -> ContinuousEnvironments: | 
					
						
							| 
									
										
										
										
											2024-01-10 15:37:58 +01:00
										 |  |  |         envs = super().create_envs(num_training_envs, num_test_envs) | 
					
						
							|  |  |  |         assert isinstance(envs, ContinuousEnvironments) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # obs norm wrapper | 
					
						
							| 
									
										
										
										
											2023-10-24 13:52:30 +02:00
										 |  |  |         if self.obs_norm: | 
					
						
							| 
									
										
										
										
											2024-01-10 15:37:58 +01:00
										 |  |  |             envs.train_envs = VectorEnvNormObs(envs.train_envs) | 
					
						
							|  |  |  |             envs.test_envs = VectorEnvNormObs(envs.test_envs, update_obs_rms=False) | 
					
						
							|  |  |  |             envs.test_envs.set_obs_rms(envs.train_envs.get_obs_rms()) | 
					
						
							| 
									
										
										
										
											2023-10-24 13:52:30 +02:00
										 |  |  |             envs.set_persistence(MujocoEnvObsRmsPersistence()) | 
					
						
							| 
									
										
										
										
											2024-01-10 15:37:58 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-10-12 15:01:49 +02:00
										 |  |  |         return envs |