2023-10-12 17:40:16 +02:00
|
|
|
import logging
|
2023-10-12 15:01:49 +02:00
|
|
|
import pickle
|
2022-05-05 07:55:15 -04:00
|
|
|
|
2024-02-29 15:59:11 +01:00
|
|
|
from tianshou.env import BaseVectorEnv, VectorEnvNormObs
|
2024-01-10 15:37:58 +01:00
|
|
|
from tianshou.highlevel.env import (
|
|
|
|
ContinuousEnvironments,
|
2024-01-16 12:22:07 +01:00
|
|
|
EnvFactoryRegistered,
|
2024-02-29 15:59:11 +01:00
|
|
|
EnvMode,
|
2024-01-10 15:37:58 +01:00
|
|
|
EnvPoolFactory,
|
|
|
|
VectorEnvType,
|
|
|
|
)
|
2023-10-12 17:40:16 +02:00
|
|
|
from tianshou.highlevel.persistence import Persistence, PersistEvent, RestoreEvent
|
2023-10-12 15:01:49 +02:00
|
|
|
from tianshou.highlevel.world import World
|
2022-05-17 17:41:59 +02:00
|
|
|
|
2024-01-16 12:16:46 +01:00
|
|
|
# EnvPool is an optional dependency: try to import it and record the outcome,
# so that MujocoEnvFactory only passes an EnvPoolFactory when it is importable.
try:
    import envpool
except ImportError:
    # Keep the name bound so that later references to `envpool` do not fail.
    envpool = None
    envpool_is_available = False
else:
    envpool_is_available = True

log = logging.getLogger(__name__)
|
|
|
|
|
2022-05-05 07:55:15 -04:00
|
|
|
|
2023-09-20 09:29:34 +02:00
|
|
|
def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: int, obs_norm: bool):
    """Wrapper function for Mujoco env.

    If EnvPool is installed, it will automatically switch to EnvPool's Mujoco env.

    :param task: the task identifier, passed on to :class:`MujocoEnvFactory`
    :param seed: the seed, passed on to :class:`MujocoEnvFactory`
    :param num_train_envs: the number of training environments to create
    :param num_test_envs: the number of test environments to create
    :param obs_norm: whether the factory shall wrap the vectorized environments
        with observation normalization
    :return: a tuple of (single env, training envs, test envs).
    """
    envs = MujocoEnvFactory(task, seed, obs_norm=obs_norm).create_envs(
        num_train_envs,
        num_test_envs,
    )
    return envs.env, envs.train_envs, envs.test_envs
|
2023-09-19 18:53:11 +02:00
|
|
|
|
|
|
|
|
2023-10-12 15:01:49 +02:00
|
|
|
class MujocoEnvObsRmsPersistence(Persistence):
    """Persists and restores the `obs_rms` value of the environments.

    The value is read from the training environments via `get_obs_rms()`,
    pickled to :attr:`FILENAME`, and applied on restoration to the training,
    test and (if present) watch environments via `set_obs_rms()`.
    """

    # Name of the pickle file, resolved via the world's persist/restore paths.
    FILENAME = "env_obs_rms.pkl"

    def persist(self, event: PersistEvent, world: World) -> None:
        """Pickle the training environments' `obs_rms` value.

        :param event: the persistence event; anything other than
            `PersistEvent.PERSIST_POLICY` is ignored
        :param world: provides the environments and the target path
        """
        if event != PersistEvent.PERSIST_POLICY:
            return
        obs_rms = world.envs.train_envs.get_obs_rms()
        path = world.persist_path(self.FILENAME)
        log.info(f"Saving environment obs_rms value to {path}")
        with open(path, "wb") as f:
            pickle.dump(obs_rms, f)

    def restore(self, event: RestoreEvent, world: World) -> None:
        """Load the pickled `obs_rms` value and apply it to all environments.

        :param event: the restoration event; anything other than
            `RestoreEvent.RESTORE_POLICY` is ignored
        :param world: provides the environments and the source path
        """
        if event != RestoreEvent.RESTORE_POLICY:
            return
        path = world.restore_path(self.FILENAME)
        log.info(f"Restoring environment obs_rms value from {path}")
        with open(path, "rb") as f:
            # SECURITY NOTE: unpickling can execute arbitrary code; only restore
            # from directories whose content is trusted.
            obs_rms = pickle.load(f)
        world.envs.train_envs.set_obs_rms(obs_rms)
        world.envs.test_envs.set_obs_rms(obs_rms)
        # The watch environment is optional and may not have been created.
        if world.envs.watch_env is not None:
            world.envs.watch_env.set_obs_rms(obs_rms)
|
2023-10-12 15:01:49 +02:00
|
|
|
|
|
|
|
|
2024-01-16 12:22:07 +01:00
|
|
|
class MujocoEnvFactory(EnvFactoryRegistered):
    """Environment factory for Mujoco tasks with optional observation normalization.

    An :class:`EnvPoolFactory` is handed to the base class only if the
    `envpool` package could be imported.
    """

    def __init__(
        self,
        task: str,
        seed: int,
        obs_norm: bool = True,
        venv_type: VectorEnvType = VectorEnvType.SUBPROC_SHARED_MEM,
    ) -> None:
        """Initialize the factory.

        :param task: the task identifier, passed on to the base class
        :param seed: the seed, passed on to the base class
        :param obs_norm: whether to wrap created vectorized environments in
            :class:`VectorEnvNormObs`
        :param venv_type: the type of vectorized environment, passed on to the base class
        """
        super().__init__(
            task=task,
            seed=seed,
            venv_type=venv_type,
            envpool_factory=EnvPoolFactory() if envpool_is_available else None,
        )
        self.obs_norm = obs_norm

    def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv:
        """Create vectorized environments.

        :param num_envs: the number of environments
        :param mode: the mode for which to create; with observation normalization
            enabled, `obs_rms` updates are switched on only for `EnvMode.TRAIN`
        :return: the vectorized environments
        """
        env = super().create_venv(num_envs, mode)
        # obs norm wrapper
        if self.obs_norm:
            env = VectorEnvNormObs(env, update_obs_rms=mode == EnvMode.TRAIN)
        return env

    def create_envs(
        self,
        num_training_envs: int,
        num_test_envs: int,
        create_watch_env: bool = False,
    ) -> ContinuousEnvironments:
        """Create the environments and, if enabled, set up observation normalization.

        With `obs_norm` enabled, the test and (if created) watch environments
        receive the `obs_rms` value of the training environments, and
        :class:`MujocoEnvObsRmsPersistence` is registered on the result.

        :param num_training_envs: the number of training environments
        :param num_test_envs: the number of test environments
        :param create_watch_env: whether to additionally create a watch environment
        :return: the created environments
        """
        envs = super().create_envs(num_training_envs, num_test_envs, create_watch_env)
        # Narrows the base class's return type to the one declared by this method.
        assert isinstance(envs, ContinuousEnvironments)

        if self.obs_norm:
            envs.test_envs.set_obs_rms(envs.train_envs.get_obs_rms())
            if envs.watch_env is not None:
                envs.watch_env.set_obs_rms(envs.train_envs.get_obs_rms())
            envs.set_persistence(MujocoEnvObsRmsPersistence())
        return envs
|