Changes:
- Added a disclaimer to the README.
- Replaced all occurrences of Gym with Gymnasium.
- Removed code that is now dead, since we no longer need to support the old step API.
- Updated type hints to only allow the new step API.
- Increased the required version of envpool to support Gymnasium.
- Increased the required version of PettingZoo to support Gymnasium.
- Updated `PettingZooEnv` to only use the new step API, and removed the hack that also supported the old API.
- Had to add some `# type: ignore` comments because of the new type hints in Gymnasium. I'm not very familiar with type hinting, but I believe the issue is on the Gymnasium side, and we are looking into it.
- Had to update `MyTestEnv` to support the `options` kwarg.
- Skipped the NNI tests because they still use OpenAI Gym.
- Also allowed `PettingZooEnv` in vector environments.
- Updated the docs page about ReplayBuffer to also cover the `terminated` and `truncated` flags.

Still to do:
- Update the Jupyter notebooks in the docs.
- Check the entire code base for more dead code (left over from compatibility code).
- Check the reset functions of all environments/wrappers in the code base to make sure they use the `options` kwarg.
- Someone might want to check test_env_finite.py.
- Is it okay to allow `PettingZooEnv` in vector environments? If so, the docs may need to be updated.
44 lines
1.3 KiB
Python
44 lines
1.3 KiB
Python
import warnings
|
|
|
|
import gymnasium as gym
|
|
|
|
from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
|
|
|
|
try:
|
|
import envpool
|
|
except ImportError:
|
|
envpool = None
|
|
|
|
|
|
def make_mujoco_env(task, seed, training_num, test_num, obs_norm):
    """Wrapper function for Mujoco env.

    If EnvPool is installed, it will automatically switch to EnvPool's Mujoco env.

    :param task: environment id passed to ``envpool.make_gymnasium`` or
        ``gym.make``.
    :param seed: random seed used for the single env and both vectorized envs.
    :param training_num: number of training environments.
    :param test_num: number of test environments.
    :param obs_norm: if True, wrap the training and test envs in
        ``VectorEnvNormObs``. The test envs are built with
        ``update_obs_rms=False`` and are given the training envs' current
        observation statistics.

    :return: a tuple of (single env, training envs, test envs).
    """
    if envpool is not None:
        # EnvPool seeds its environments at construction time via ``seed=``.
        # The training pool is also returned as the "single env".
        train_envs = env = envpool.make_gymnasium(
            task, num_envs=training_num, seed=seed
        )
        test_envs = envpool.make_gymnasium(task, num_envs=test_num, seed=seed)
    else:
        warnings.warn(
            "Recommend using envpool (pip install envpool) "
            "to run Mujoco environments more efficiently."
        )
        env = gym.make(task)
        train_envs = ShmemVectorEnv(
            [lambda: gym.make(task) for _ in range(training_num)]
        )
        test_envs = ShmemVectorEnv([lambda: gym.make(task) for _ in range(test_num)])
        # Gymnasium removed ``Env.seed``; a single Gymnasium env is seeded
        # through ``reset(seed=...)``. Calling ``env.seed`` here would raise
        # AttributeError. The Tianshou vector envs keep their ``seed`` method.
        env.reset(seed=seed)
        train_envs.seed(seed)
        test_envs.seed(seed)
    if obs_norm:
        # obs norm wrapper: the test envs do not update their statistics and
        # start from the training envs' current ones.
        train_envs = VectorEnvNormObs(train_envs)
        test_envs = VectorEnvNormObs(test_envs, update_obs_rms=False)
        test_envs.set_obs_rms(train_envs.get_obs_rms())
    return env, train_envs, test_envs
|