Fix/deterministic action space sampling in SubprocVectorEnv (#1103)
This commit is contained in:
parent
6935a111d9
commit
a043711c10
57
test/base/test_action_space_sampling.py
Normal file
57
test/base/test_action_space_sampling.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
import gymnasium as gym
|
||||||
|
|
||||||
|
from tianshou.env import DummyVectorEnv, ShmemVectorEnv, SubprocVectorEnv
|
||||||
|
|
||||||
|
|
||||||
|
def test_gym_env_action_space() -> None:
|
||||||
|
env = gym.make("Pendulum-v1")
|
||||||
|
env.action_space.seed(0)
|
||||||
|
action1 = env.action_space.sample()
|
||||||
|
|
||||||
|
env.action_space.seed(0)
|
||||||
|
action2 = env.action_space.sample()
|
||||||
|
|
||||||
|
assert action1 == action2
|
||||||
|
|
||||||
|
|
||||||
|
def test_dummy_vec_env_action_space() -> None:
|
||||||
|
num_envs = 4
|
||||||
|
envs = DummyVectorEnv([lambda: gym.make("Pendulum-v1") for _ in range(num_envs)])
|
||||||
|
envs.seed(0)
|
||||||
|
action1 = [ac_space.sample() for ac_space in envs.action_space]
|
||||||
|
|
||||||
|
envs.seed(0)
|
||||||
|
action2 = [ac_space.sample() for ac_space in envs.action_space]
|
||||||
|
|
||||||
|
assert action1 == action2
|
||||||
|
|
||||||
|
|
||||||
|
def test_subproc_vec_env_action_space() -> None:
|
||||||
|
num_envs = 4
|
||||||
|
envs = SubprocVectorEnv([lambda: gym.make("Pendulum-v1") for _ in range(num_envs)])
|
||||||
|
envs.seed(0)
|
||||||
|
action1 = [ac_space.sample() for ac_space in envs.action_space]
|
||||||
|
|
||||||
|
envs.seed(0)
|
||||||
|
action2 = [ac_space.sample() for ac_space in envs.action_space]
|
||||||
|
|
||||||
|
assert action1 == action2
|
||||||
|
|
||||||
|
|
||||||
|
def test_shmem_vec_env_action_space() -> None:
|
||||||
|
num_envs = 4
|
||||||
|
envs = ShmemVectorEnv([lambda: gym.make("Pendulum-v1") for _ in range(num_envs)])
|
||||||
|
envs.seed(0)
|
||||||
|
action1 = [ac_space.sample() for ac_space in envs.action_space]
|
||||||
|
|
||||||
|
envs.seed(0)
|
||||||
|
action2 = [ac_space.sample() for ac_space in envs.action_space]
|
||||||
|
|
||||||
|
assert action1 == action2
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_gym_env_action_space()
|
||||||
|
test_dummy_vec_env_action_space()
|
||||||
|
test_subproc_vec_env_action_space()
|
||||||
|
test_shmem_vec_env_action_space()
|
1
tianshou/env/worker/subproc.py
vendored
1
tianshou/env/worker/subproc.py
vendored
@ -126,6 +126,7 @@ def _worker(
|
|||||||
if hasattr(env, "seed"):
|
if hasattr(env, "seed"):
|
||||||
p.send(env.seed(data))
|
p.send(env.seed(data))
|
||||||
else:
|
else:
|
||||||
|
env.action_space.seed(seed=data)
|
||||||
env.reset(seed=data)
|
env.reset(seed=data)
|
||||||
p.send(None)
|
p.send(None)
|
||||||
elif cmd == "getattr":
|
elif cmd == "getattr":
|
||||||
|
Loading…
x
Reference in New Issue
Block a user