2022-05-05 07:55:15 -04:00
|
|
|
import warnings
|
|
|
|
|
2023-02-03 20:57:27 +01:00
|
|
|
import gymnasium as gym
|
2022-05-17 17:41:59 +02:00
|
|
|
|
|
|
|
from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
|
|
|
|
|
2022-05-05 07:55:15 -04:00
|
|
|
try:
|
|
|
|
import envpool
|
|
|
|
except ImportError:
|
|
|
|
envpool = None
|
|
|
|
|
|
|
|
|
2023-07-26 20:24:33 +02:00
|
|
|
def make_mujoco_env(
    task: str,
    seed: int,
    num_train_envs: int,
    num_test_envs: int,
    obs_norm: bool,
):
    """Create a single Mujoco env together with vectorized train and test envs.

    When EnvPool could be imported, both vectorized envs are built with
    ``envpool.make_gymnasium`` and the returned single env is the very same
    object as the (un-normalized) training envs. Otherwise a warning is issued
    and the envs are built from ``gym.make`` wrapped in ``ShmemVectorEnv``.

    :param task: environment id handed to EnvPool / ``gym.make``.
    :param seed: seed applied to both the training and the test envs.
    :param num_train_envs: number of parallel training envs.
    :param num_test_envs: number of parallel test envs.
    :param obs_norm: if True, wrap both vectorized envs in ``VectorEnvNormObs``;
        the test wrapper does not update its own statistics and is initialised
        from the result of the training wrapper's ``get_obs_rms()``.
    :return: a tuple of (single env, training envs, test envs).
    """

    def _shmem_vector(count: int):
        # One independent factory per worker; each factory only closes over `task`.
        return ShmemVectorEnv([lambda: gym.make(task) for _ in range(count)])

    if envpool is None:
        warnings.warn(
            "Recommend using envpool (pip install envpool) "
            "to run Mujoco environments more efficiently.",
        )
        env = gym.make(task)
        train_envs = _shmem_vector(num_train_envs)
        test_envs = _shmem_vector(num_test_envs)
        # Only the vectorized envs are seeded here; the standalone `env` is not.
        train_envs.seed(seed)
        test_envs.seed(seed)
    else:
        # EnvPool receives the seed at construction time.
        train_envs = envpool.make_gymnasium(task, num_envs=num_train_envs, seed=seed)
        # The "single env" handed back is the raw training env pool itself.
        env = train_envs
        test_envs = envpool.make_gymnasium(task, num_envs=num_test_envs, seed=seed)

    if not obs_norm:
        return env, train_envs, test_envs

    # obs norm wrapper: statistics are updated by the training envs only and
    # the test envs start from whatever the training wrapper currently holds.
    train_envs = VectorEnvNormObs(train_envs)
    test_envs = VectorEnvNormObs(test_envs, update_obs_rms=False)
    test_envs.set_obs_rms(train_envs.get_obs_rms())
    return env, train_envs, test_envs
|