2021-04-25 15:23:46 +08:00
|
|
|
import sys
|
2020-03-11 10:56:38 +08:00
|
|
|
import time
|
2021-09-03 05:05:04 +08:00
|
|
|
|
2023-02-03 20:57:27 +01:00
|
|
|
import gymnasium as gym
|
2020-03-11 10:56:38 +08:00
|
|
|
import numpy as np
|
2022-05-05 07:55:15 -04:00
|
|
|
import pytest
|
2023-02-03 20:57:27 +01:00
|
|
|
from gymnasium.spaces.discrete import Discrete
|
2021-09-03 05:05:04 +08:00
|
|
|
|
2020-07-26 12:01:21 +02:00
|
|
|
from tianshou.data import Batch
|
2022-05-05 07:55:15 -04:00
|
|
|
from tianshou.env import (
|
2022-06-13 00:18:22 +02:00
|
|
|
ContinuousToDiscrete,
|
2022-05-05 07:55:15 -04:00
|
|
|
DummyVectorEnv,
|
2022-06-13 00:18:22 +02:00
|
|
|
MultiDiscreteToDiscrete,
|
2022-05-05 07:55:15 -04:00
|
|
|
RayVectorEnv,
|
|
|
|
ShmemVectorEnv,
|
|
|
|
SubprocVectorEnv,
|
|
|
|
VectorEnvNormObs,
|
|
|
|
)
|
2022-10-31 08:54:54 +09:00
|
|
|
from tianshou.env.gym_wrappers import TruncatedAsTerminated
|
2022-05-05 07:55:15 -04:00
|
|
|
from tianshou.utils import RunningMeanStd
|
2020-03-11 10:56:38 +08:00
|
|
|
|
2022-07-31 11:00:38 -07:00
|
|
|
if __name__ == "__main__":
|
2021-04-25 15:23:46 +08:00
|
|
|
from env import MyTestEnv, NXEnv
|
2020-03-21 10:58:01 +08:00
|
|
|
else: # pytest
|
2021-04-25 15:23:46 +08:00
|
|
|
from test.base.env import MyTestEnv, NXEnv
|
2020-03-11 10:56:38 +08:00
|
|
|
|
2022-05-05 07:55:15 -04:00
|
|
|
try:
|
|
|
|
import envpool
|
|
|
|
except ImportError:
|
|
|
|
envpool = None
|
|
|
|
|
2020-03-11 16:14:53 +08:00
|
|
|
|
2020-08-19 15:00:24 +08:00
|
|
|
def has_ray():
    """Tell whether the optional ``ray`` package can be imported.

    :return: ``True`` when ``import ray`` succeeds, ``False`` when it raises
        ``ImportError``.
    """
    try:
        import ray  # noqa: F401
    except ImportError:
        return False
    else:
        return True
|
|
|
|
|
|
|
|
|
2020-08-04 13:39:05 +08:00
|
|
|
def recurse_comp(a, b):
    """Recursively compare two (possibly nested) structures for equality.

    The type of ``a`` selects the comparison strategy:

    * numeric ``np.ndarray``: ``np.allclose(a, b)``;
    * object-dtype ``np.ndarray``, ``list`` or ``tuple``: both sides must have
      the same length and every pair of elements must compare equal;
    * ``dict``: every key of ``a`` is looked up in ``b`` and the values are
      compared recursively (a key missing from ``b`` means "not equal");
    * numbers: ``np.allclose(a, b)``;
    * anything else (strings, ``None``, ...): plain ``a == b``.

    :param a: reference structure.
    :param b: structure compared against ``a``.
    :return: ``True`` if the two structures match, ``False`` otherwise. Any
        exception raised while comparing is treated as a mismatch.
    """
    try:
        if isinstance(a, np.ndarray) and a.dtype != object:
            return bool(np.allclose(a, b))
        if isinstance(a, (np.ndarray, list, tuple)):
            # zip() alone would silently ignore trailing items of the longer
            # side, so the lengths have to be checked explicitly.
            return len(a) == len(b) and all(recurse_comp(m, n) for m, n in zip(a, b))
        if isinstance(a, dict):
            return all(recurse_comp(a[k], b[k]) for k in a)
        if isinstance(a, (int, float, complex, np.number, np.bool_)):
            return bool(np.allclose(a, b))
        # Leaf that is neither a container nor a number: previously this
        # fell through and returned None, which made equal leaves look unequal.
        return bool(a == b)
    except Exception:
        # Best-effort comparison: incompatible shapes/types count as unequal.
        return False
|
|
|
|
|
|
|
|
|
2020-08-19 15:00:24 +08:00
|
|
|
def test_async_env(size=10000, num=8, sleep=0.1):
    """Keep stepping async vector envs and check they beat the synchronous time bound.

    :param size: ``size`` passed to the first ``MyTestEnv``; env ``k`` gets ``size + k``.
    :param num: number of environments per vector env.
    :param sleep: ``sleep`` value passed to every ``MyTestEnv`` (created with
        ``random_sleep=True``).
    """
    # simplify the test case, just keep stepping
    env_fns = [
        lambda i=i: MyTestEnv(size=i, sleep=sleep, random_sleep=True)
        for i in range(size, size + num)
    ]
    test_cls = [SubprocVectorEnv, ShmemVectorEnv]
    if has_ray():
        test_cls += [RayVectorEnv]
    for cls in test_cls:
        v = cls(env_fns, wait_num=num // 2, timeout=1e-3)
        # Close the workers even when stepping fails, so that a broken run
        # does not leave worker processes behind.
        try:
            v.seed(None)
            v.reset()
            # for a random variable u ~ U[0, 1], let v = max{u1, u2, ..., un}
            # P(v <= x) = x^n (0 <= x <= 1), pdf of v is nx^{n-1}
            # expectation of v is n / (n + 1)
            # for a synchronous environment, the following actions should take
            # about 7 * sleep * num / (num + 1) seconds
            # for async simulation, the analysis is complicated, but the time cost
            # should be smaller
            action_list = [1] * num + [0] * (num * 2) + [1] * (num * 4)
            current_idx_start = 0
            act = action_list[:num]
            env_ids = list(range(num))
            o = []
            spent_time = time.time()
            while current_idx_start < len(action_list):
                obs, rew, terminated, truncated, info = v.step(action=act, id=env_ids)
                b = Batch(
                    {
                        "obs": obs,
                        "rew": rew,
                        "terminate": terminated,
                        "truncated": truncated,
                        "info": info,
                    },
                )
                # only the envs that just returned are stepped next
                env_ids = b.info.env_id
                o.append(b)
                current_idx_start += len(act)
                # len of action may be smaller than len(obs) in the end
                act = action_list[current_idx_start : current_idx_start + len(obs)]
                # truncate env_ids with the first terms
                # typically len(env_ids) == len(obs) == len(action), except for the
                # last batch when actions are not enough
                env_ids = env_ids[: len(act)]
            spent_time = time.time() - spent_time
            # the collected batches must be concatenable
            Batch.cat(o)
        finally:
            v.close()
        # assure 1/7 improvement
        if sys.platform == "linux" and cls != RayVectorEnv:
            # macOS/Windows cannot pass this check
            assert spent_time < 6.0 * sleep * num / (num + 1)
|
2020-08-19 15:00:24 +08:00
|
|
|
|
|
|
|
|
2023-08-25 23:40:56 +02:00
|
|
|
def test_async_check_id(size=100, num=4, sleep=0.2, timeout=0.7):
    """Check which env ids an async vector env returns on consecutive steps.

    Four ``MyTestEnv`` instances are created with ``sleep`` values of 2, 3, 5
    and 7 times ``sleep``. Each vector env class is driven with
    ``wait_num=num - 1`` and the given ``timeout``; the sorted ids returned by
    every step are compared with a fixed expectation, and a step is expected
    to finish before ``timeout`` exactly when it returns ``num - 1`` ids.

    :param size: ``size`` passed to every ``MyTestEnv``.
    :param num: number of environments assumed by the expectations (the env
        list itself always has four entries).
    :param sleep: base sleep unit.
    :param timeout: ``timeout`` passed to the vector env.
    """
    env_fns = [
        lambda: MyTestEnv(size=size, sleep=sleep * 2),
        lambda: MyTestEnv(size=size, sleep=sleep * 3),
        lambda: MyTestEnv(size=size, sleep=sleep * 5),
        lambda: MyTestEnv(size=size, sleep=sleep * 7),
    ]
    test_cls = [SubprocVectorEnv, ShmemVectorEnv]
    if has_ray():
        test_cls += [RayVectorEnv]
    # Sorted env ids expected from each consecutive step call; identical for
    # every vector env class, so it is built once.
    expect_result = [
        [0, 1],
        [0, 1, 2],
        [0, 1, 3],
        [0, 1, 2],
        [0, 1],
        [0, 2, 3],
        [0, 1],
    ]
    total_pass = 0
    for cls in test_cls:
        pass_check = 1
        v = cls(env_fns, wait_num=num - 1, timeout=timeout)
        # The vector env was previously never closed; release its workers
        # explicitly, also when a step raises.
        try:
            t = time.time()
            v.reset()
            t = time.time() - t
            print(f"{cls} reset {t}")
            if t > sleep * 9:  # larger than the maximum sleep time (7 sleep)
                pass_check = 0
            ids = np.arange(num)
            for res in expect_result:
                t = time.time()
                _, _, _, _, info = v.step([1] * len(ids), ids)
                t = time.time() - t
                ids = Batch(info).env_id
                print(ids, t)
                if not (
                    len(ids) == len(res)
                    and np.allclose(sorted(ids), res)
                    and (t < timeout) == (len(res) == num - 1)
                ):
                    pass_check = 0
                    break
        finally:
            v.close()
        total_pass += pass_check
    if sys.platform == "linux":  # Windows/macOS may not pass this check
        assert total_pass >= 2
|
2020-07-26 12:01:21 +02:00
|
|
|
|
|
|
|
|
2020-03-16 11:11:29 +08:00
|
|
|
def test_vecenv(size=10, num=8, sleep=0.001):
    """Check that all vector env implementations behave identically.

    Every implementation is seeded with 0 and driven with the same action
    sequence; observations, rewards and both done flags must match the first
    implementation at every step. Afterwards ``get_env_attr``/``set_env_attr``
    are exercised on each implementation.

    :param size: ``size`` of the first ``MyTestEnv``; env ``k`` gets ``size + k``.
    :param num: number of environments per vector env.
    :param sleep: ``sleep`` value passed to every ``MyTestEnv``.
    """
    env_fns = [
        lambda i=i: MyTestEnv(size=i, sleep=sleep, recurse_state=True)
        for i in range(size, size + num)
    ]
    venv = [DummyVectorEnv(env_fns), SubprocVectorEnv(env_fns), ShmemVectorEnv(env_fns)]
    if has_ray() and sys.platform == "linux":
        venv += [RayVectorEnv(env_fns)]
    for vec_env in venv:
        vec_env.seed(0)
    action_list = [1] * 5 + [0] * 10 + [1] * 20
    step_results = [vec_env.reset()[0] for vec_env in venv]
    for action in action_list:
        step_results = []
        for vec_env in venv:
            obs, rew, terminated, truncated, info = vec_env.step([action] * num)
            finished = terminated + truncated
            if sum(finished):
                obs, _ = vec_env.reset(np.where(finished)[0])
            step_results.append([obs, rew, terminated, truncated, info])
        # Transpose: one tuple per returned field, holding that field from
        # every implementation.
        for field_idx, per_impl in enumerate(zip(*step_results)):
            if field_idx != 4:  # do not check info here
                for candidate in per_impl:
                    assert recurse_comp(per_impl[0], candidate)

    if __name__ == "__main__":
        # Rough timing comparison, only when run as a script.
        elapsed = []
        for vec_env in venv:
            started = time.time()
            vec_env.reset()
            for action in action_list:
                done = vec_env.step([action] * num)[2]
                if sum(done) > 0:
                    vec_env.reset(np.where(done)[0])
            elapsed.append(time.time() - started)
        for vec_env, seconds in zip(venv, elapsed):
            print(f"{type(vec_env)}: {seconds:.6f}s")

    def assert_get(v, expected):
        """Check ``get_env_attr("size")`` for all ids, a single id and an id list."""
        assert v.get_env_attr("size") == expected
        assert v.get_env_attr("size", id=0) == [expected[0]]
        assert v.get_env_attr("size", id=[0, 1, 2]) == expected[:3]

    for vec_env in venv:
        assert_get(vec_env, list(range(size, size + num)))
        assert vec_env.env_num == num
        assert vec_env.action_space == [Discrete(2)] * num

        # set on all envs
        vec_env.set_env_attr("size", 0)
        assert_get(vec_env, [0] * num)

        # set on a single env id
        vec_env.set_env_attr("size", 1, 0)
        assert_get(vec_env, [1] + [0] * (num - 1))

        # set on a list of env ids
        vec_env.set_env_attr("size", 2, [1, 2, 3])
        assert_get(vec_env, [1] + [2] * 3 + [0] * (num - 4))

    for vec_env in venv:
        vec_env.close()
|
2020-03-11 16:14:53 +08:00
|
|
|
|
|
|
|
|
2022-07-14 22:52:56 -07:00
|
|
|
def test_attr_unwrapped():
    """Check that ``set_env_attr`` is visible on both the env and its ``unwrapped`` env."""
    train_envs = DummyVectorEnv([lambda: gym.make("CartPole-v1")])
    train_envs.set_env_attr("test_attribute", 1337)
    values = train_envs.get_env_attr("test_attribute")
    assert values == [1337]
    # the attribute must exist on the worker's env and on the env beneath it
    assert hasattr(train_envs.workers[0].env, "test_attribute")
    assert hasattr(train_envs.workers[0].env.unwrapped, "test_attribute")
|
|
|
|
|
|
|
|
|
2022-05-05 07:55:15 -04:00
|
|
|
def test_env_obs_dtype():
    """Check that ``SubprocVectorEnv`` returns object-dtype observations for ``NXEnv``.

    Both ``NXEnv`` observation types ("array" and "object") are tested, on the
    observations returned by ``reset`` and by ``step``.
    """
    for obs_type in ["array", "object"]:
        envs = SubprocVectorEnv([lambda i=x, t=obs_type: NXEnv(i, t) for x in [5, 10, 15, 20]])
        # The vector env was previously never closed; release its workers
        # explicitly, also when an assertion fails.
        try:
            obs, info = envs.reset()
            assert obs.dtype == object
            obs = envs.step([1, 1, 1, 1])[0]
            assert obs.dtype == object
        finally:
            envs.close()
|
|
|
|
|
|
|
|
|
2022-06-27 18:52:21 -04:00
|
|
|
def test_env_reset_optional_kwargs(size=10000, num=8):
    """Check that ``reset`` accepts a ``seed`` keyword and returns one info dict per env.

    :param size: ``size`` of the first ``MyTestEnv``; env ``k`` gets ``size + k``.
    :param num: number of environments per vector env.
    """
    env_fns = [lambda i=i: MyTestEnv(size=i) for i in range(size, size + num)]
    test_cls = [DummyVectorEnv, SubprocVectorEnv, ShmemVectorEnv]
    if has_ray():
        test_cls += [RayVectorEnv]
    for cls in test_cls:
        v = cls(env_fns, wait_num=num // 2, timeout=1e-3)
        # The vector env was previously never closed; release its workers
        # explicitly, also when an assertion fails.
        try:
            _, info = v.reset(seed=1)
            assert len(info) == len(env_fns)
            assert isinstance(info[0], dict)
        finally:
            v.close()
|
|
|
|
|
|
|
|
|
2022-07-31 11:00:38 -07:00
|
|
|
def test_venv_wrapper_gym(num_envs: int = 4):
    """Check the ``reset`` return types of ``VectorEnvNormObs`` around gym envs (issue 697).

    :param num_envs: number of CartPole environments to wrap.
    """
    # Issue 697
    make_cartpole_fns = [lambda: gym.make("CartPole-v1") for _ in range(num_envs)]
    envs = VectorEnvNormObs(DummyVectorEnv(make_cartpole_fns))
    try:
        reset_out = envs.reset()
        obs, info = reset_out
    except ValueError:
        # fall back to explicitly requesting the info
        obs, info = envs.reset(return_info=True)
    assert isinstance(obs, np.ndarray)
    assert isinstance(info, list)
    assert isinstance(info[0], dict)
    # one observation row and one info dict per environment
    assert obs.shape[0] == len(info)
    assert len(info) == num_envs
|
2022-07-31 11:00:38 -07:00
|
|
|
|
|
|
|
|
2022-05-05 07:55:15 -04:00
|
|
|
def run_align_norm_obs(raw_env, train_env, test_env, action_list):
    """Check observation normalization of ``train_env``/``test_env`` against ``raw_env``.

    All three envs are rolled out with ``action_list``. A reference
    ``RunningMeanStd`` is updated with the raw observations and used to
    normalize them; the result must match what ``train_env`` returned, and its
    statistics must match ``get_obs_rms()`` of both ``train_env`` and
    ``test_env``. Finally the raw observations, normalized with the final
    reference statistics, must match the observations of ``test_env``.

    :param raw_env: vector env returning raw observations.
    :param train_env: env whose observations are compared with the running
        normalization; must provide ``get_obs_rms()``.
    :param test_env: env whose observations are compared with the final
        normalization; must provide ``get_obs_rms()``.
    :param action_list: actions, one per step, applied to every env.
    """

    def reset_result_to_obs(reset_result):
        """Extract observation from reset result (result is possibly a tuple containing info)."""
        if isinstance(reset_result, tuple) and len(reset_result) == 2:
            return reset_result[0]
        return reset_result

    def advance(env, action, history):
        """Step ``env`` once, record the obs, then reset finished envs and record that obs too."""
        step_result = env.step(action)
        if len(step_result) == 5:
            obs, _, terminated, truncated, _ = step_result
            done = np.logical_or(terminated, truncated)
        else:
            # 4-tuple step result with a single done flag
            obs, _, done, _ = step_result
        history.append(obs)
        if np.any(done):
            history.append(reset_result_to_obs(env.reset(np.where(done)[0])))

    eps = np.finfo(np.float32).eps.item()

    raw_reset_result = raw_env.reset()
    train_reset_result = train_env.reset()
    raw_obs = [reset_result_to_obs(raw_reset_result)]
    train_obs = [reset_result_to_obs(train_reset_result)]
    for action in action_list:
        advance(raw_env, action, raw_obs)
        advance(train_env, action, train_obs)

    # Replay the raw observations through a reference running mean/std and
    # compare each normalized observation with what train_env produced.
    ref_rms = RunningMeanStd()
    for raw, normalized in zip(raw_obs, train_obs):
        ref_rms.update(raw)
        expected = (raw - ref_rms.mean) / np.sqrt(ref_rms.var + eps)
        assert np.allclose(expected, normalized)

    for env in (train_env, test_env):
        env_rms_mean = env.get_obs_rms().mean
        assert np.allclose(ref_rms.mean, env_rms_mean)
        env_rms_var = env.get_obs_rms().var
        assert np.allclose(ref_rms.var, env_rms_var)

    test_obs = [reset_result_to_obs(test_env.reset())]
    for action in action_list:
        advance(test_env, action, test_obs)

    # test_env is compared against the final reference statistics only
    for raw, normalized in zip(raw_obs, test_obs):
        expected = (raw - ref_rms.mean) / np.sqrt(ref_rms.var + eps)
        assert np.allclose(expected, normalized)
|
|
|
|
|
|
|
|
|
|
|
|
def test_venv_norm_obs():
    """Run the normalization alignment check on ``MyTestEnv`` vector envs."""
    sizes = np.array([5, 10, 15, 20])
    total_step = 30
    # the same action array is applied at every step
    action = np.array([1, 1, 1, 1])
    action_list = [action] * total_step
    env_fns = [lambda i=x: MyTestEnv(size=i, array_state=True) for x in sizes]

    raw = DummyVectorEnv(env_fns)
    train_env = VectorEnvNormObs(DummyVectorEnv(env_fns))
    print(train_env.observation_space)
    # test env does not update its statistics; it receives the train env's rms
    test_env = VectorEnvNormObs(DummyVectorEnv(env_fns), update_obs_rms=False)
    test_env.set_obs_rms(train_env.get_obs_rms())

    run_align_norm_obs(raw, train_env, test_env, action_list)
|
|
|
|
|
|
|
|
|
2022-06-13 00:18:22 +02:00
|
|
|
def test_gym_wrappers():
    """Check the action conversions of the gym wrappers and ``TruncatedAsTerminated``."""

    class DummyEnv(gym.Env):
        """Env with a 4-dim box action space whose step always reports truncation."""

        def __init__(self):
            self.action_space = gym.spaces.Box(low=-1.0, high=2.0, shape=(4,), dtype=np.float32)
            self.observation_space = gym.spaces.Discrete(2)

        def step(self, act):
            # (obs, reward, terminated, truncated, info)
            return self.observation_space.sample(), -1, False, True, {}

    bsz = 10
    action_per_branch = [4, 6, 10, 7]
    env = DummyEnv()
    original_act = env.action_space.high

    # convert continuous to multidiscrete action space
    # with different action number per dimension
    env_m = ContinuousToDiscrete(env, action_per_branch)
    highest_multi = env_m.action_space.nvec - 1
    # check conversion is working properly for one action
    np.testing.assert_allclose(env_m.action(highest_multi), original_act)
    # check conversion is working properly for a batch of actions
    batch_multi = np.array([env_m.action_space.nvec - 1] * bsz)
    np.testing.assert_allclose(
        env_m.action(batch_multi),
        np.array([original_act] * bsz),
    )

    # convert multidiscrete with different action number per
    # dimension to discrete action space
    env_d = MultiDiscreteToDiscrete(env_m)
    highest_flat = env_d.action_space.n - 1
    # check conversion is working properly for one action
    np.testing.assert_allclose(env_d.action(highest_flat), env_m.action_space.nvec - 1)
    # check conversion is working properly for a batch of actions
    batch_flat = np.array([env_d.action_space.n - 1] * bsz)
    np.testing.assert_allclose(
        env_d.action(batch_flat),
        np.array([env_m.action_space.nvec - 1] * bsz),
    )

    # check that the third element of the step result (the terminated flag)
    # is truthy although DummyEnv itself only reports truncation
    try:
        env_t = TruncatedAsTerminated(env)
    except OSError:
        env_t = None
    if env_t is not None:
        step_out = env_t.step(env_t.action_space.sample())
        _, _, terminated, _, _ = step_out
        assert terminated
|
2022-06-13 00:18:22 +02:00
|
|
|
|
|
|
|
|
2022-05-05 07:55:15 -04:00
|
|
|
@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
def test_venv_wrapper_envpool():
    """Run the normalization alignment check on EnvPool ``Ant-v3`` envs."""
    raw = envpool.make_gymnasium("Ant-v3", num_envs=4)
    train = VectorEnvNormObs(envpool.make_gymnasium("Ant-v3", num_envs=4))
    test_pool = envpool.make_gymnasium("Ant-v3", num_envs=4)
    test = VectorEnvNormObs(test_pool, update_obs_rms=False)
    test.set_obs_rms(train.get_obs_rms())
    # 30 steps, each with one sampled action per environment
    actions = []
    for _ in range(30):
        per_env = [raw.action_space.sample() for _ in range(4)]
        actions.append(np.array(per_env))
    run_align_norm_obs(raw, train, test, actions)
|
|
|
|
|
|
|
|
|
2022-06-27 18:52:21 -04:00
|
|
|
@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
def test_venv_wrapper_envpool_gym_reset_return_info():
    """Check that ``reset`` on a wrapped EnvPool env returns batched obs and info."""
    num_envs = 4
    pool = envpool.make_gymnasium("Ant-v3", num_envs=num_envs, gym_reset_return_info=True)
    env = VectorEnvNormObs(pool)
    obs, info = env.reset()
    assert obs.shape[0] == num_envs
    # every non-dict info entry must have one row per environment
    for value in info.values():
        if isinstance(value, dict):
            continue
        assert value.shape[0] == num_envs
|
|
|
|
|
|
|
|
|
2022-07-31 11:00:38 -07:00
|
|
|
if __name__ == "__main__":
    # Run every test of this module in a fixed order when executed as a script.
    for run_test in (
        test_venv_norm_obs,
        test_venv_wrapper_gym,
        test_venv_wrapper_envpool,
        test_venv_wrapper_envpool_gym_reset_return_info,
        test_env_obs_dtype,
        test_vecenv,
        test_attr_unwrapped,
        test_async_env,
        test_async_check_id,
        test_env_reset_optional_kwargs,
        test_gym_wrappers,
    ):
        run_test()
|