Tianshou/test/base/test_action_space_sampling.py
Michael Panchenko 12d4262f80 Tests: removed all instances of if __name__ == ... in tests
A test is not a script and should not be used as such

Also marked pistonball test as skipped since it doesn't actually test anything
2024-04-26 17:39:30 +02:00

51 lines
1.4 KiB
Python

import gymnasium as gym
from tianshou.env import DummyVectorEnv, ShmemVectorEnv, SubprocVectorEnv
def test_gym_env_action_space() -> None:
    """Check that re-seeding a gymnasium action space reproduces its sample.

    The action space of a ``Pendulum-v1`` environment is seeded with ``0`` and
    sampled, then seeded with ``0`` again and sampled a second time. Both
    samples must compare equal.
    """
    env = gym.make("Pendulum-v1")
    try:
        env.action_space.seed(0)
        action1 = env.action_space.sample()
        # Re-seed with the same value before drawing again; the test expects
        # an identical sample.
        env.action_space.seed(0)
        action2 = env.action_space.sample()
        assert action1 == action2
    finally:
        # Release the environment even when the assertion fails.
        env.close()
def test_dummy_vec_env_action_space() -> None:
    """Check that re-seeding a ``DummyVectorEnv`` reproduces its action samples.

    Four ``Pendulum-v1`` environments are wrapped in a ``DummyVectorEnv``. The
    vector env is seeded with ``0`` and one action is sampled from each entry
    of ``envs.action_space``; the same is done a second time after re-seeding
    with ``0``. Both lists of samples must compare equal.
    """
    num_envs = 4
    envs = DummyVectorEnv([lambda: gym.make("Pendulum-v1") for _ in range(num_envs)])
    try:
        envs.seed(0)
        action1 = [ac_space.sample() for ac_space in envs.action_space]
        # Re-seed with the same value before drawing again; the test expects
        # identical samples.
        envs.seed(0)
        action2 = [ac_space.sample() for ac_space in envs.action_space]
        assert action1 == action2
    finally:
        # Close the vector env even when the assertion fails.
        envs.close()
def test_subproc_vec_env_action_space() -> None:
    """Check that re-seeding a ``SubprocVectorEnv`` reproduces its action samples.

    Four ``Pendulum-v1`` environments are wrapped in a ``SubprocVectorEnv``.
    The vector env is seeded with ``0`` and one action is sampled from each
    entry of ``envs.action_space``; the same is done a second time after
    re-seeding with ``0``. Both lists of samples must compare equal.
    """
    num_envs = 4
    envs = SubprocVectorEnv([lambda: gym.make("Pendulum-v1") for _ in range(num_envs)])
    try:
        envs.seed(0)
        action1 = [ac_space.sample() for ac_space in envs.action_space]
        # Re-seed with the same value before drawing again; the test expects
        # identical samples.
        envs.seed(0)
        action2 = [ac_space.sample() for ac_space in envs.action_space]
        assert action1 == action2
    finally:
        # Close the vector env even when the assertion fails, so whatever
        # resources it holds do not outlive the test.
        envs.close()
def test_shmem_vec_env_action_space() -> None:
    """Check that re-seeding a ``ShmemVectorEnv`` reproduces its action samples.

    Four ``Pendulum-v1`` environments are wrapped in a ``ShmemVectorEnv``. The
    vector env is seeded with ``0`` and one action is sampled from each entry
    of ``envs.action_space``; the same is done a second time after re-seeding
    with ``0``. Both lists of samples must compare equal.
    """
    num_envs = 4
    envs = ShmemVectorEnv([lambda: gym.make("Pendulum-v1") for _ in range(num_envs)])
    try:
        envs.seed(0)
        action1 = [ac_space.sample() for ac_space in envs.action_space]
        # Re-seed with the same value before drawing again; the test expects
        # identical samples.
        envs.seed(0)
        action2 = [ac_space.sample() for ac_space in envs.action_space]
        assert action1 == action2
    finally:
        # Close the vector env even when the assertion fails, so whatever
        # resources it holds do not outlive the test.
        envs.close()