89 lines
		
	
	
		
			3.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			89 lines
		
	
	
		
			3.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import logging
 | |
| import pickle
 | |
| import warnings
 | |
| 
 | |
| import gymnasium as gym
 | |
| 
 | |
| from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
 | |
| from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
 | |
| from tianshou.highlevel.persistence import Persistence, PersistEvent, RestoreEvent
 | |
| from tianshou.highlevel.world import World
 | |
| 
 | |
| try:
 | |
|     import envpool
 | |
| except ImportError:
 | |
|     envpool = None
 | |
| 
 | |
| log = logging.getLogger(__name__)
 | |
| 
 | |
| 
 | |
| def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: int, obs_norm: bool):
 | |
|     """Wrapper function for Mujoco env.
 | |
| 
 | |
|     If EnvPool is installed, it will automatically switch to EnvPool's Mujoco env.
 | |
| 
 | |
|     :return: a tuple of (single env, training envs, test envs).
 | |
|     """
 | |
|     if envpool is not None:
 | |
|         train_envs = env = envpool.make_gymnasium(task, num_envs=num_train_envs, seed=seed)
 | |
|         test_envs = envpool.make_gymnasium(task, num_envs=num_test_envs, seed=seed)
 | |
|     else:
 | |
|         warnings.warn(
 | |
|             "Recommend using envpool (pip install envpool) "
 | |
|             "to run Mujoco environments more efficiently.",
 | |
|         )
 | |
|         env = gym.make(task)
 | |
|         train_envs = ShmemVectorEnv([lambda: gym.make(task) for _ in range(num_train_envs)])
 | |
|         test_envs = ShmemVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)])
 | |
|         train_envs.seed(seed)
 | |
|         test_envs.seed(seed)
 | |
|     if obs_norm:
 | |
|         # obs norm wrapper
 | |
|         train_envs = VectorEnvNormObs(train_envs)
 | |
|         test_envs = VectorEnvNormObs(test_envs, update_obs_rms=False)
 | |
|         test_envs.set_obs_rms(train_envs.get_obs_rms())
 | |
|     return env, train_envs, test_envs
 | |
| 
 | |
| 
 | |
| class MujocoEnvObsRmsPersistence(Persistence):
 | |
|     FILENAME = "env_obs_rms.pkl"
 | |
| 
 | |
|     def persist(self, event: PersistEvent, world: World) -> None:
 | |
|         if event != PersistEvent.PERSIST_POLICY:
 | |
|             return
 | |
|         obs_rms = world.envs.train_envs.get_obs_rms()
 | |
|         path = world.persist_path(self.FILENAME)
 | |
|         log.info(f"Saving environment obs_rms value to {path}")
 | |
|         with open(path, "wb") as f:
 | |
|             pickle.dump(obs_rms, f)
 | |
| 
 | |
|     def restore(self, event: RestoreEvent, world: World):
 | |
|         if event != RestoreEvent.RESTORE_POLICY:
 | |
|             return
 | |
|         path = world.restore_path(self.FILENAME)
 | |
|         log.info(f"Restoring environment obs_rms value from {path}")
 | |
|         with open(path, "rb") as f:
 | |
|             obs_rms = pickle.load(f)
 | |
|         world.envs.train_envs.set_obs_rms(obs_rms)
 | |
|         world.envs.test_envs.set_obs_rms(obs_rms)
 | |
| 
 | |
| 
 | |
| class MujocoEnvFactory(EnvFactory):
 | |
|     def __init__(self, task: str, seed: int, obs_norm=True):
 | |
|         self.task = task
 | |
|         self.seed = seed
 | |
|         self.obs_norm = obs_norm
 | |
| 
 | |
|     def create_envs(self, num_training_envs: int, num_test_envs: int) -> ContinuousEnvironments:
 | |
|         env, train_envs, test_envs = make_mujoco_env(
 | |
|             task=self.task,
 | |
|             seed=self.seed,
 | |
|             num_train_envs=num_training_envs,
 | |
|             num_test_envs=num_test_envs,
 | |
|             obs_norm=self.obs_norm,
 | |
|         )
 | |
|         envs = ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)
 | |
|         if self.obs_norm:
 | |
|             envs.set_persistence(MujocoEnvObsRmsPersistence())
 | |
|         return envs
 |