Tianshou/examples/mujoco/mujoco_env.py
Dominik Jain 3691ed2abc Support obs_rms persistence for MuJoCo by adding a general mechanism
for attaching persistence to Environments instances
2023-10-18 20:44:17 +02:00

89 lines
3.2 KiB
Python

import pickle
import warnings
import logging
import gymnasium as gym
from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
from tianshou.highlevel.persistence import Persistence, RestoreEvent, PersistEvent
from tianshou.highlevel.world import World
# EnvPool is an optional, faster backend for MuJoCo; when it is not installed,
# ``envpool`` is set to None and make_mujoco_env falls back to gymnasium.
try:
    import envpool
except ImportError:
    envpool = None

# Module-level logger used by the persistence hooks below.
log = logging.getLogger(__name__)
def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: int, obs_norm: bool):
    """Build the environments needed to train and evaluate on a MuJoCo task.

    When the optional ``envpool`` package could be imported, EnvPool's gymnasium
    interface is used. Otherwise a warning is issued and plain gymnasium
    environments are wrapped in ``ShmemVectorEnv`` instances.

    :param task: the task id passed to ``envpool.make_gymnasium`` / ``gym.make``.
    :param seed: seed applied to both the training and the test environments.
    :param num_train_envs: number of parallel training environments.
    :param num_test_envs: number of parallel test environments.
    :param obs_norm: if True, wrap both vector envs in ``VectorEnvNormObs``; the
        test envs start from the training envs' obs_rms and do not update it.
    :return: a tuple of (single env, training envs, test envs). In the EnvPool
        case the "single env" is the (unwrapped) EnvPool training env itself.
    """
    if envpool is None:
        warnings.warn(
            "Recommend using envpool (pip install envpool) "
            "to run Mujoco environments more efficiently.",
        )

        def make_single():
            return gym.make(task)

        env = make_single()
        train_envs = ShmemVectorEnv([make_single] * num_train_envs)
        test_envs = ShmemVectorEnv([make_single] * num_test_envs)
        train_envs.seed(seed)
        test_envs.seed(seed)
    else:
        # EnvPool seeds at construction time, so no separate seed() call is made.
        env = envpool.make_gymnasium(task, num_envs=num_train_envs, seed=seed)
        train_envs = env
        test_envs = envpool.make_gymnasium(task, num_envs=num_test_envs, seed=seed)

    if not obs_norm:
        return env, train_envs, test_envs

    # Observation normalization: only the training envs update the running
    # statistics; the test envs are initialized from the training envs' obs_rms.
    normed_train = VectorEnvNormObs(train_envs)
    normed_test = VectorEnvNormObs(test_envs, update_obs_rms=False)
    normed_test.set_obs_rms(normed_train.get_obs_rms())
    return env, normed_train, normed_test
class MujocoEnvObsRmsPersistence(Persistence):
    """Persists the observation-normalization statistics (obs_rms) of MuJoCo environments.

    On policy persistence, the value returned by ``world.envs.train_envs.get_obs_rms()``
    is pickled to ``FILENAME`` in the persistence directory. On policy restoration, it is
    unpickled and applied to both the training and the test environments. This requires
    the environments to expose ``get_obs_rms``/``set_obs_rms`` (e.g. ``VectorEnvNormObs``).
    """

    # Name of the file (relative to the world's persist/restore directory) holding obs_rms.
    FILENAME = "env_obs_rms.pkl"

    def persist(self, event: PersistEvent, world: World) -> None:
        """Save the training environments' obs_rms; only reacts to ``PERSIST_POLICY``.

        :param event: the persistence event; all other events are ignored.
        :param world: provides the environments and the target path.
        """
        if event != PersistEvent.PERSIST_POLICY:
            return
        obs_rms = world.envs.train_envs.get_obs_rms()
        path = world.persist_path(self.FILENAME)
        log.info(f"Saving environment obs_rms value to {path}")
        with open(path, "wb") as f:
            pickle.dump(obs_rms, f)

    def restore(self, event: RestoreEvent, world: World) -> None:
        """Load obs_rms and apply it to training and test envs; only reacts to ``RESTORE_POLICY``.

        :param event: the restoration event; all other events are ignored.
        :param world: provides the environments and the source path.
        """
        if event != RestoreEvent.RESTORE_POLICY:
            return
        path = world.restore_path(self.FILENAME)
        log.info(f"Restoring environment obs_rms value from {path}")
        # SECURITY: pickle.load can execute arbitrary code embedded in the file.
        # Only restore from directories whose contents you trust.
        with open(path, "rb") as f:
            obs_rms = pickle.load(f)
        world.envs.train_envs.set_obs_rms(obs_rms)
        world.envs.test_envs.set_obs_rms(obs_rms)
class MujocoEnvFactory(EnvFactory):
    """Creates MuJoCo environments via ``make_mujoco_env``.

    :param task: the task id to create environments for.
    :param seed: the seed used for the training and test environments.
    :param sampling_config: supplies ``num_train_envs`` and ``num_test_envs``.
    :param obs_norm: whether to apply observation normalization to the vectorized
        environments (default True, matching the previous hard-coded behavior). When
        enabled, the normalization statistics are persisted and restored alongside
        the policy via ``MujocoEnvObsRmsPersistence``.
    """

    def __init__(
        self,
        task: str,
        seed: int,
        sampling_config: SamplingConfig,
        obs_norm: bool = True,
    ):
        self.task = task
        self.sampling_config = sampling_config
        self.seed = seed
        self.obs_norm = obs_norm

    def create_envs(self, config=None) -> ContinuousEnvironments:
        """Create the single, training and test environments.

        :param config: not used by this factory.
        :return: the environments, with obs_rms persistence attached if ``obs_norm`` is set.
        """
        env, train_envs, test_envs = make_mujoco_env(
            task=self.task,
            seed=self.seed,
            num_train_envs=self.sampling_config.num_train_envs,
            num_test_envs=self.sampling_config.num_test_envs,
            obs_norm=self.obs_norm,
        )
        envs = ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)
        if self.obs_norm:
            # obs_rms only exists when the vector envs are wrapped in VectorEnvNormObs;
            # attaching the persistence otherwise would fail on get_obs_rms/set_obs_rms.
            envs.set_persistence(MujocoEnvObsRmsPersistence())
        return envs