Tianshou/examples/mujoco/mujoco_env.py
Markus Krimmel 6c6c872523
Gymnasium Integration (#789)
Changes:
- Disclaimer in README
- Replaced all occurrences of Gym with Gymnasium
- Removed code that is now dead since we no longer need to support the
old step API
- Updated type hints to only allow new step API
- Increased required version of envpool to support Gymnasium
- Increased required version of PettingZoo to support Gymnasium
- Updated `PettingZooEnv` to only use the new step API, removed hack to
also support old API
- I had to add some `# type: ignore` comments, due to new type hinting
in Gymnasium. I'm not that familiar with type hinting but I believe that
the issue is on the Gymnasium side and we are looking into it.
- Had to update `MyTestEnv` to support `options` kwarg
- Skip NNI tests because they still use OpenAI Gym
- Also allow `PettingZooEnv` in vector environment
- Updated doc page about ReplayBuffer to also talk about terminated and
truncated flags.
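
For illustration, the new step API and the `options` kwarg mentioned above can be sketched with a toy environment. This is a minimal stdlib-only sketch, not code from this commit: `CountdownEnv`, its `start` option, and `run_episode` are hypothetical names, and the class only mimics the Gymnasium conventions (`reset` returns `(obs, info)`, `step` returns `(obs, reward, terminated, truncated, info)`).

```python
import random


class CountdownEnv:
    """Hypothetical toy env that mimics the Gymnasium-style reset/step signatures."""

    def __init__(self, goal=3, max_steps=5):
        self.goal = goal
        self.max_steps = max_steps
        self._rng = random.Random()
        self._pos = 0
        self._steps = 0

    def reset(self, seed=None, options=None):
        # Seeding happens through reset(seed=...) rather than a separate seed() method.
        if seed is not None:
            self._rng.seed(seed)
        options = options or {}
        # "start" is an assumed option name; without it the start is drawn from the RNG.
        self._pos = options.get("start", self._rng.randrange(2))
        self._steps = 0
        return self._pos, {}

    def step(self, action):
        self._pos += action
        self._steps += 1
        # terminated: the task itself ended (goal reached).
        terminated = self._pos >= self.goal
        # truncated: the episode was cut off by the step limit, not by the task.
        truncated = (not terminated) and self._steps >= self.max_steps
        reward = 1.0 if terminated else 0.0
        return self._pos, reward, terminated, truncated, {}


def run_episode(env, action):
    """Repeat one action until the episode ends; return the final step result."""
    while True:
        obs, reward, terminated, truncated, _info = env.step(action)
        if terminated or truncated:
            return obs, reward, terminated, truncated
```

Keeping the two flags separate lets a learner bootstrap from the final observation of a truncated episode while treating a terminated one as a true end state.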

Still need to do: 
- Update the Jupyter notebooks in docs
- Check the entire code base for more dead code (from compatibility
stuff)
- Check the reset functions of all environments/wrappers in code base to
make sure they use the `options` kwarg
- Someone might want to check test_env_finite.py
- Is it okay to allow `PettingZooEnv` in vector environments? Might need
to update docs?
2023-02-03 11:57:27 -08:00


import warnings

import gymnasium as gym

from tianshou.env import ShmemVectorEnv, VectorEnvNormObs

try:
    import envpool
except ImportError:
    # EnvPool is optional; fall back to plain Gymnasium environments below.
    envpool = None


def make_mujoco_env(task, seed, training_num, test_num, obs_norm):
    """Wrapper function for Mujoco env.

    If EnvPool is installed, it will automatically switch to EnvPool's Mujoco env.

    :return: a tuple of (single env, training envs, test envs).
    """
    if envpool is not None:
        # EnvPool takes the seed at construction time.
        train_envs = env = envpool.make_gymnasium(
            task, num_envs=training_num, seed=seed
        )
        test_envs = envpool.make_gymnasium(task, num_envs=test_num, seed=seed)
    else:
        warnings.warn(
            "Recommend using envpool (pip install envpool) "
            "to run Mujoco environments more efficiently."
        )
        env = gym.make(task)
        train_envs = ShmemVectorEnv(
            [lambda: gym.make(task) for _ in range(training_num)]
        )
        test_envs = ShmemVectorEnv([lambda: gym.make(task) for _ in range(test_num)])
        # Gymnasium environments have no seed() method; seed the single env
        # through reset() instead. The Tianshou vector envs still expose seed().
        env.reset(seed=seed)
        train_envs.seed(seed)
        test_envs.seed(seed)
    if obs_norm:
        # obs norm wrapper: only the training envs update the running
        # statistics; the test envs reuse them without updating.
        train_envs = VectorEnvNormObs(train_envs)
        test_envs = VectorEnvNormObs(test_envs, update_obs_rms=False)
        test_envs.set_obs_rms(train_envs.get_obs_rms())
    return env, train_envs, test_envs
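
The observation-normalization step at the end of `make_mujoco_env` can be illustrated with a stdlib-only sketch. It is not Tianshou's implementation: `RunningMeanStd` and `NormObs` here are hypothetical scalar stand-ins, written under the assumption that the training wrapper updates running statistics while the test wrapper reuses the same statistics object with updates disabled.

```python
import math


class RunningMeanStd:
    """Hypothetical scalar running mean/variance tracker (not Tianshou's class)."""

    def __init__(self):
        self.mean = 0.0
        self.var = 1.0
        self.count = 0

    def update(self, batch):
        # Merge a batch into the running statistics (parallel variance formula).
        n = len(batch)
        if n == 0:
            return
        b_mean = sum(batch) / n
        b_var = sum((x - b_mean) ** 2 for x in batch) / n
        total = self.count + n
        delta = b_mean - self.mean
        m2 = self.var * self.count + b_var * n + delta**2 * self.count * n / total
        self.mean += delta * n / total
        self.var = m2 / total
        self.count = total

    def normalize(self, x, eps=1e-8):
        return (x - self.mean) / math.sqrt(self.var + eps)


class NormObs:
    """Hypothetical stand-in for a normalizing wrapper around a batch of observations."""

    def __init__(self, update_obs_rms=True):
        self.obs_rms = RunningMeanStd()
        self.update_obs_rms = update_obs_rms

    def process(self, batch):
        # Only wrappers with update_obs_rms=True change the statistics.
        if self.update_obs_rms:
            self.obs_rms.update(batch)
        return [self.obs_rms.normalize(x) for x in batch]

    def get_obs_rms(self):
        return self.obs_rms

    def set_obs_rms(self, obs_rms):
        self.obs_rms = obs_rms


# Mirror the wiring in make_mujoco_env: test shares the training statistics.
train = NormObs()
test = NormObs(update_obs_rms=False)
test.set_obs_rms(train.get_obs_rms())

train.process([1.0, 2.0, 3.0])
train.process([4.0, 5.0])
test_out = test.process([3.0])  # normalized with training statistics, no update
```

Freezing updates on the test side keeps evaluation observations on the same scale the policy saw during training, and evaluation data never leaks into the statistics.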