Fix/deterministic action space sampling in SubprocVectorEnv (#1103)

This commit is contained in:
maxhuettenrauch 2024-04-18 16:16:57 +02:00 committed by GitHub
parent 6935a111d9
commit a043711c10
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 58 additions and 0 deletions

View File

@ -0,0 +1,57 @@
import gymnasium as gym
from tianshou.env import DummyVectorEnv, ShmemVectorEnv, SubprocVectorEnv
def test_gym_env_action_space() -> None:
env = gym.make("Pendulum-v1")
env.action_space.seed(0)
action1 = env.action_space.sample()
env.action_space.seed(0)
action2 = env.action_space.sample()
assert action1 == action2
def test_dummy_vec_env_action_space() -> None:
num_envs = 4
envs = DummyVectorEnv([lambda: gym.make("Pendulum-v1") for _ in range(num_envs)])
envs.seed(0)
action1 = [ac_space.sample() for ac_space in envs.action_space]
envs.seed(0)
action2 = [ac_space.sample() for ac_space in envs.action_space]
assert action1 == action2
def test_subproc_vec_env_action_space() -> None:
num_envs = 4
envs = SubprocVectorEnv([lambda: gym.make("Pendulum-v1") for _ in range(num_envs)])
envs.seed(0)
action1 = [ac_space.sample() for ac_space in envs.action_space]
envs.seed(0)
action2 = [ac_space.sample() for ac_space in envs.action_space]
assert action1 == action2
def test_shmem_vec_env_action_space() -> None:
num_envs = 4
envs = ShmemVectorEnv([lambda: gym.make("Pendulum-v1") for _ in range(num_envs)])
envs.seed(0)
action1 = [ac_space.sample() for ac_space in envs.action_space]
envs.seed(0)
action2 = [ac_space.sample() for ac_space in envs.action_space]
assert action1 == action2
if __name__ == "__main__":
test_gym_env_action_space()
test_dummy_vec_env_action_space()
test_subproc_vec_env_action_space()
test_shmem_vec_env_action_space()

View File

@ -126,6 +126,7 @@ def _worker(
if hasattr(env, "seed"):
p.send(env.seed(data))
else:
env.action_space.seed(seed=data)
env.reset(seed=data)
p.send(None)
elif cmd == "getattr":