import pickle import warnings import logging import gymnasium as gym from tianshou.env import ShmemVectorEnv, VectorEnvNormObs from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory from tianshou.highlevel.persistence import Persistence, RestoreEvent, PersistEvent from tianshou.highlevel.world import World try: import envpool except ImportError: envpool = None log = logging.getLogger(__name__) def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: int, obs_norm: bool): """Wrapper function for Mujoco env. If EnvPool is installed, it will automatically switch to EnvPool's Mujoco env. :return: a tuple of (single env, training envs, test envs). """ if envpool is not None: train_envs = env = envpool.make_gymnasium(task, num_envs=num_train_envs, seed=seed) test_envs = envpool.make_gymnasium(task, num_envs=num_test_envs, seed=seed) else: warnings.warn( "Recommend using envpool (pip install envpool) " "to run Mujoco environments more efficiently.", ) env = gym.make(task) train_envs = ShmemVectorEnv([lambda: gym.make(task) for _ in range(num_train_envs)]) test_envs = ShmemVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)]) train_envs.seed(seed) test_envs.seed(seed) if obs_norm: # obs norm wrapper train_envs = VectorEnvNormObs(train_envs) test_envs = VectorEnvNormObs(test_envs, update_obs_rms=False) test_envs.set_obs_rms(train_envs.get_obs_rms()) return env, train_envs, test_envs class MujocoEnvObsRmsPersistence(Persistence): FILENAME = "env_obs_rms.pkl" def persist(self, event: PersistEvent, world: World) -> None: if event != PersistEvent.PERSIST_POLICY: return obs_rms = world.envs.train_envs.get_obs_rms() path = world.persist_path(self.FILENAME) log.info(f"Saving environment obs_rms value to {path}") with open(path, "wb") as f: pickle.dump(obs_rms, f) def restore(self, event: RestoreEvent, world: World): if event != RestoreEvent.RESTORE_POLICY: return path = world.restore_path(self.FILENAME) log.info(f"Restoring environment obs_rms value from {path}") with open(path, "rb") as f: obs_rms = pickle.load(f) world.envs.train_envs.set_obs_rms(obs_rms) world.envs.test_envs.set_obs_rms(obs_rms) class MujocoEnvFactory(EnvFactory): def __init__(self, task: str, seed: int, sampling_config: SamplingConfig): self.task = task self.sampling_config = sampling_config self.seed = seed def create_envs(self, config=None) -> ContinuousEnvironments: env, train_envs, test_envs = make_mujoco_env( task=self.task, seed=self.seed, num_train_envs=self.sampling_config.num_train_envs, num_test_envs=self.sampling_config.num_test_envs, obs_norm=True, ) envs = ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs) envs.set_persistence(MujocoEnvObsRmsPersistence()) return envs