of number of environments in SamplingConfig is used (values are now passed to factory method). This is clearer and removes the need to pass otherwise unnecessary configuration to environment factories at construction.
93 lines
3.2 KiB
Python
93 lines
3.2 KiB
Python
import logging
|
|
import pickle
|
|
import warnings
|
|
|
|
import gymnasium as gym
|
|
|
|
from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
|
|
from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
|
|
from tianshou.highlevel.persistence import Persistence, PersistEvent, RestoreEvent
|
|
from tianshou.highlevel.world import World
|
|
|
|
try:
|
|
import envpool
|
|
except ImportError:
|
|
envpool = None
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: int, obs_norm: bool):
    """Wrapper function for Mujoco env.

    If EnvPool is installed, it will automatically switch to EnvPool's Mujoco env.
    Otherwise a warning is issued and the environments are created with
    ``gym.make`` and vectorized with ``ShmemVectorEnv``.

    :param task: the task identifier passed to ``envpool.make_gymnasium`` / ``gym.make``.
    :param seed: the seed used for both the training and the test environments.
    :param num_train_envs: the number of training environments to create.
    :param num_test_envs: the number of test environments to create.
    :param obs_norm: whether to wrap the training and test environments in
        ``VectorEnvNormObs``.
    :return: a tuple of (single env, training envs, test envs).
    """
    if envpool is not None:
        # In the EnvPool case the training pool itself is returned as the "single env".
        train_envs = env = envpool.make_gymnasium(task, num_envs=num_train_envs, seed=seed)
        test_envs = envpool.make_gymnasium(task, num_envs=num_test_envs, seed=seed)
    else:
        warnings.warn(
            "Recommend using envpool (pip install envpool) "
            "to run Mujoco environments more efficiently.",
            # attribute the warning to the caller rather than to this helper
            stacklevel=2,
        )
        env = gym.make(task)
        train_envs = ShmemVectorEnv([lambda: gym.make(task) for _ in range(num_train_envs)])
        test_envs = ShmemVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)])
        # EnvPool receives the seed at construction; here it is applied after creation.
        train_envs.seed(seed)
        test_envs.seed(seed)
    if obs_norm:
        # obs norm wrapper
        train_envs = VectorEnvNormObs(train_envs)
        # The test wrapper is created with update_obs_rms=False and is handed the
        # obs_rms object obtained from the training wrapper.
        test_envs = VectorEnvNormObs(test_envs, update_obs_rms=False)
        test_envs.set_obs_rms(train_envs.get_obs_rms())
    return env, train_envs, test_envs
|
|
|
|
|
|
class MujocoEnvObsRmsPersistence(Persistence):
    """Saves and restores the environments' ``obs_rms`` value as a pickle file."""

    # Name of the file (resolved via the world's persist/restore path) holding obs_rms.
    FILENAME = "env_obs_rms.pkl"

    def persist(self, event: PersistEvent, world: World) -> None:
        """Pickle the training environments' ``obs_rms`` value.

        Does nothing unless ``event`` is ``PersistEvent.PERSIST_POLICY``.

        :param event: the persistence event being handled.
        :param world: provides the environments and the target path.
        """
        if event != PersistEvent.PERSIST_POLICY:
            return
        rms_value = world.envs.train_envs.get_obs_rms()
        target_path = world.persist_path(self.FILENAME)
        log.info(f"Saving environment obs_rms value to {target_path}")
        with open(target_path, "wb") as out_file:
            pickle.dump(rms_value, out_file)

    def restore(self, event: RestoreEvent, world: World):
        """Load the pickled ``obs_rms`` value and set it on the training and test envs.

        Does nothing unless ``event`` is ``RestoreEvent.RESTORE_POLICY``.

        :param event: the restoration event being handled.
        :param world: provides the environments and the source path.
        """
        if event != RestoreEvent.RESTORE_POLICY:
            return
        source_path = world.restore_path(self.FILENAME)
        log.info(f"Restoring environment obs_rms value from {source_path}")
        # NOTE(review): pickle.load executes arbitrary code from the file; only
        # restore from persistence directories that are trusted.
        with open(source_path, "rb") as in_file:
            rms_value = pickle.load(in_file)
        # Training envs first, then test envs.
        for vector_env in (world.envs.train_envs, world.envs.test_envs):
            vector_env.set_obs_rms(rms_value)
|
|
|
|
|
|
class MujocoEnvFactory(EnvFactory):
    """Environment factory which creates Mujoco environments via ``make_mujoco_env``."""

    def __init__(self, task: str, seed: int, obs_norm=True):
        """Store the settings used when the environments are created.

        :param task: the task identifier forwarded to ``make_mujoco_env``.
        :param seed: the seed forwarded to ``make_mujoco_env``.
        :param obs_norm: the ``obs_norm`` flag forwarded to ``make_mujoco_env``.
        """
        self.task, self.seed, self.obs_norm = task, seed, obs_norm

    def create_envs(
        self,
        num_training_envs: int,
        num_test_envs: int,
        config=None,
    ) -> ContinuousEnvironments:
        """Create the environments and attach obs_rms persistence to them.

        :param num_training_envs: the number of training environments to create.
        :param num_test_envs: the number of test environments to create.
        :param config: not used by this implementation.
        :return: the environments, with a ``MujocoEnvObsRmsPersistence`` instance set.
        """
        single_env, training_envs, testing_envs = make_mujoco_env(
            task=self.task,
            seed=self.seed,
            num_train_envs=num_training_envs,
            num_test_envs=num_test_envs,
            obs_norm=self.obs_norm,
        )
        environments = ContinuousEnvironments(
            env=single_env,
            train_envs=training_envs,
            test_envs=testing_envs,
        )
        environments.set_persistence(MujocoEnvObsRmsPersistence())
        return environments
|