From a043711c10fc564e950e34ff1c45b1f8f06a57d6 Mon Sep 17 00:00:00 2001
From: maxhuettenrauch
Date: Thu, 18 Apr 2024 16:16:57 +0200
Subject: [PATCH] Fix/deterministic action space sampling in SubprocVectorEnv (#1103)

---
Review notes (ignored when the patch is applied):
 * The new tests close every environment they create, so the worker
   subprocesses of SubprocVectorEnv/ShmemVectorEnv are not leaked.
 * Sampled actions are compared with numpy.testing.assert_array_equal
   instead of `==`, which only works for single-element actions.
 * The blob id on the "index" line of the new test file predates this
   revision of the test.

 test/base/test_action_space_sampling.py | 61 +++++++++++++++++++++++++
 tianshou/env/worker/subproc.py          |  1 +
 2 files changed, 62 insertions(+)
 create mode 100644 test/base/test_action_space_sampling.py

diff --git a/test/base/test_action_space_sampling.py b/test/base/test_action_space_sampling.py
new file mode 100644
index 0000000..fbbf25c
--- /dev/null
+++ b/test/base/test_action_space_sampling.py
@@ -0,0 +1,61 @@
+import gymnasium as gym
+import numpy as np
+
+from tianshou.env import DummyVectorEnv, ShmemVectorEnv, SubprocVectorEnv
+
+NUM_ENVS = 4
+
+
+def _assert_reseeding_reproduces_actions(venv_cls: type) -> None:
+    """Check that re-seeding a vector env reproduces its sampled actions.
+
+    :param venv_cls: vector env class, called with a list of env factories.
+    """
+    envs = venv_cls([lambda: gym.make("Pendulum-v1") for _ in range(NUM_ENVS)])
+    try:
+        envs.seed(0)
+        actions1 = [ac_space.sample() for ac_space in envs.action_space]
+
+        envs.seed(0)
+        actions2 = [ac_space.sample() for ac_space in envs.action_space]
+    finally:
+        # Always shut the vector env down so worker subprocesses are not leaked.
+        envs.close()
+
+    assert len(actions1) == len(actions2)
+    # Compare array-wise: `==` on lists of arrays breaks for multi-dimensional actions.
+    for action1, action2 in zip(actions1, actions2):
+        np.testing.assert_array_equal(action1, action2)
+
+
+def test_gym_env_action_space() -> None:
+    env = gym.make("Pendulum-v1")
+    try:
+        env.action_space.seed(0)
+        action1 = env.action_space.sample()
+
+        env.action_space.seed(0)
+        action2 = env.action_space.sample()
+    finally:
+        env.close()
+
+    np.testing.assert_array_equal(action1, action2)
+
+
+def test_dummy_vec_env_action_space() -> None:
+    _assert_reseeding_reproduces_actions(DummyVectorEnv)
+
+
+def test_subproc_vec_env_action_space() -> None:
+    _assert_reseeding_reproduces_actions(SubprocVectorEnv)
+
+
+def test_shmem_vec_env_action_space() -> None:
+    _assert_reseeding_reproduces_actions(ShmemVectorEnv)
+
+
+if __name__ == "__main__":
+    test_gym_env_action_space()
+    test_dummy_vec_env_action_space()
+    test_subproc_vec_env_action_space()
+    test_shmem_vec_env_action_space()
diff --git a/tianshou/env/worker/subproc.py b/tianshou/env/worker/subproc.py
index ab84ac0..2637df4 100644
--- a/tianshou/env/worker/subproc.py
+++ b/tianshou/env/worker/subproc.py
@@ -126,6 +126,7 @@ def _worker(
                 if hasattr(env, "seed"):
                     p.send(env.seed(data))
                 else:
+                    env.action_space.seed(seed=data)
                     env.reset(seed=data)
                     p.send(None)
             elif cmd == "getattr":