Tianshou/test/base/test_env.py
Juno T d42a5fb354
Hindsight Experience Replay as a replay buffer (#753)
## implementation
I implemented HER solely as a replay buffer. It works by temporarily
rewriting the transition storage (`self._meta`) in place during the
`sample_indices()` call. The original transitions are cached and are
restored at the beginning of the next sampling or when other methods are
called. This makes sure that, for example, n-step return calculation
can be done without altering the policy.

There is also a problem with the original index sampling: the sampled
indices are not guaranteed to come from different episodes. So I decided
to perform the rewriting per episode. This guarantees that the
sampled transitions from the same episode will have the same rewritten
goal. This also makes the rewriting ratio calculation differ slightly
from the paper, but it won't be too different if there are many episodes
in the buffer.

In the current commit, the HER replay buffer only supports the 'future' strategy
and online sampling. This is the best variant of HER in terms of performance and
memory efficiency.

I also added a few more convenience replay buffers
(`HERVectorReplayBuffer`, `HERReplayBufferManager`), a test env
(`MyGoalEnv`), a gym wrapper (`TruncatedAsTerminated`), unit tests, and a
simple example (examples/offline/fetch_her_ddpg.py).

## verification
I have added unit tests for almost everything I have implemented.
The HER replay buffer was also tested using DDPG on [`FetchReach-v3`
env](https://github.com/Farama-Foundation/Gymnasium-Robotics). I used the
default DDPG parameters from the mujoco example and didn't tune anything
further to get this good result! (train script:
examples/offline/fetch_her_ddpg.py).


![Screen Shot 2022-10-02 at 19 22 53](https://user-images.githubusercontent.com/42699114/193454066-0dd0c65c-fd5f-4587-8912-b441d39de88a.png)
2022-10-30 16:54:54 -07:00

431 lines
15 KiB
Python

import sys
import time
import gym
import numpy as np
import pytest
from gym.spaces.discrete import Discrete
from tianshou.data import Batch
from tianshou.env import (
ContinuousToDiscrete,
DummyVectorEnv,
MultiDiscreteToDiscrete,
RayVectorEnv,
ShmemVectorEnv,
SubprocVectorEnv,
VectorEnvNormObs,
)
from tianshou.env.gym_wrappers import TruncatedAsTerminated
from tianshou.utils import RunningMeanStd
# When this file is executed directly, import the sibling ``env`` module;
# under pytest, import it through the ``test.base`` package path.
if __name__ == "__main__":
    from env import MyTestEnv, NXEnv
else:  # pytest
    from test.base.env import MyTestEnv, NXEnv

# envpool is optional: ``envpool`` is left as ``None`` when it cannot be
# imported, and the envpool tests below check for that.
try:
    import envpool
except ImportError:
    envpool = None
def has_ray():
    """Return whether the optional ``ray`` package can be imported.

    Used to decide whether ``RayVectorEnv`` should be included in a test.
    """
    try:
        import ray  # noqa: F401
    except ImportError:
        return False
    else:
        return True
def recurse_comp(a, b):
    """Recursively check whether ``a`` and ``b`` hold (approximately) equal data.

    * numeric ``np.ndarray``: shapes must be identical, values compared with
      ``np.allclose`` (no broadcasting between mismatched shapes);
    * object ``np.ndarray`` / ``list`` / ``tuple``: lengths must agree, then
      elements are compared recursively;
    * ``dict``: both must have the same key set, then values are compared
      recursively;
    * any other leaf: ``np.allclose`` when numeric, otherwise ``==``.

    Any exception raised while comparing (missing key, unsized object,
    incompatible shapes, ...) is treated as "not equal".

    :return: ``True`` if the two structures match, ``False`` otherwise.
    """
    try:
        if isinstance(a, np.ndarray):
            if a.dtype == object:
                # object arrays hold nested structures: compare element-wise
                return len(a) == len(b) and all(
                    recurse_comp(m, n) for m, n in zip(a, b)
                )
            # numeric arrays: require the exact same shape before comparing
            return a.shape == np.shape(b) and bool(np.allclose(a, b))
        if isinstance(a, (list, tuple)):
            # check lengths first: zip() would silently drop trailing items
            return len(a) == len(b) and all(
                recurse_comp(m, n) for m, n in zip(a, b)
            )
        if isinstance(a, dict):
            # extra keys on either side make the structures differ
            return set(a) == set(b.keys()) and all(
                recurse_comp(a[k], b[k]) for k in a
            )
        # scalar / other leaf: numeric tolerance when possible, else equality
        try:
            return bool(np.allclose(a, b))
        except TypeError:
            return bool(a == b)
    except Exception:
        # incompatible structures simply compare as unequal
        return False
def test_async_env(size=10000, num=8, sleep=0.1):
    """Check that async stepping beats the expected synchronous run time.

    Every env is built with ``random_sleep=True``; the comments below model the
    per-step delay as ``sleep * u`` with ``u ~ U[0, 1]``. The vector env is
    created with ``wait_num=num // 2`` and a tiny ``timeout`` so that ``step``
    may return only a subset of the envs, and the next actions are sent only to
    the env ids that were just returned. On Linux (except for Ray) the whole
    run must be at least 1/7 faster than the synchronous estimate.
    """
    # simplify the test case, just keep stepping
    env_fns = [
        lambda i=i: MyTestEnv(size=i, sleep=sleep, random_sleep=True)
        for i in range(size, size + num)
    ]
    test_cls = [SubprocVectorEnv, ShmemVectorEnv]
    if has_ray():
        test_cls += [RayVectorEnv]
    for cls in test_cls:
        v = cls(env_fns, wait_num=num // 2, timeout=1e-3)
        try:
            v.seed(None)
            v.reset()
            # for a random variable u ~ U[0, 1], let v = max{u1, u2, ..., un}
            # P(v <= x) = x^n (0 <= x <= 1), pdf of v is nx^{n-1}
            # expectation of v is n / (n + 1)
            # for a synchronous environment, the following actions should take
            # about 7 * sleep * num / (num + 1) seconds
            # for async simulation, the analysis is complicated, but the time
            # cost should be smaller
            action_list = [1] * num + [0] * (num * 2) + [1] * (num * 4)
            current_idx_start = 0
            act = action_list[:num]
            env_ids = list(range(num))
            o = []
            spent_time = time.time()
            while current_idx_start < len(action_list):
                obs, rew, terminated, truncated, info = v.step(
                    action=act, id=env_ids
                )
                b = Batch(
                    {
                        "obs": obs,
                        "rew": rew,
                        "terminate": terminated,
                        "truncated": truncated,
                        "info": info,
                    }
                )
                # only the envs that were just returned receive the next actions
                env_ids = b.info.env_id
                o.append(b)
                current_idx_start += len(act)
                # len of action may be smaller than len(obs) in the end
                act = action_list[current_idx_start:current_idx_start + len(obs)]
                # truncate env_ids with the first terms
                # typically len(env_ids) == len(obs) == len(action), except for
                # the last batch when actions are not enough
                env_ids = env_ids[:len(act)]
            spent_time = time.time() - spent_time
            # smoke check: the variable-sized step results can be concatenated
            Batch.cat(o)
        finally:
            # always release the worker processes, even if stepping fails
            v.close()
        # assure 1/7 improvement
        if sys.platform == "linux" and cls is not RayVectorEnv:
            # macOS/Windows cannot pass this check
            assert spent_time < 6.0 * sleep * num / (num + 1)
def test_async_check_id(size=100, num=4, sleep=.2, timeout=.7):
    """Check which env ids an async vector env returns at each step.

    The four envs are created with ``sleep`` scaled by 2, 3, 5 and 7. With
    ``wait_num=num - 1`` and the given ``timeout``, every ``step`` must return
    exactly the ids listed in ``expect_result``, and it must finish within
    ``timeout`` exactly when ``num - 1`` envs are returned. Because this is
    timing-sensitive, on Linux only at least two of the tested vector-env
    classes are required to pass.
    """
    env_fns = [
        lambda: MyTestEnv(size=size, sleep=sleep * 2),
        lambda: MyTestEnv(size=size, sleep=sleep * 3),
        lambda: MyTestEnv(size=size, sleep=sleep * 5),
        lambda: MyTestEnv(size=size, sleep=sleep * 7),
    ]
    test_cls = [SubprocVectorEnv, ShmemVectorEnv]
    if has_ray():
        test_cls += [RayVectorEnv]
    total_pass = 0
    for cls in test_cls:
        pass_check = 1
        v = cls(env_fns, wait_num=num - 1, timeout=timeout)
        try:
            t = time.time()
            v.reset()
            t = time.time() - t
            print(f"{cls} reset {t}")
            # much larger than the maximum sleep time (7 * sleep)
            if t > sleep * 9:
                pass_check = 0
            expect_result = [
                [0, 1],
                [0, 1, 2],
                [0, 1, 3],
                [0, 1, 2],
                [0, 1],
                [0, 2, 3],
                [0, 1],
            ]
            ids = np.arange(num)
            for res in expect_result:
                t = time.time()
                _, _, _, _, info = v.step([1] * len(ids), ids)
                t = time.time() - t
                ids = Batch(info).env_id
                print(ids, t)
                if not (
                    len(ids) == len(res) and np.allclose(sorted(ids), res) and
                    (t < timeout) == (len(res) == num - 1)
                ):
                    pass_check = 0
                    break
        finally:
            # release the worker processes before testing the next class
            v.close()
        total_pass += pass_check
    if sys.platform == "linux":  # Windows/macOS may not pass this check
        assert total_pass >= 2
def test_vecenv(size=10, num=8, sleep=0.001):
    """Check that all vector-env implementations behave identically.

    The same env factories are wrapped by every implementation and all are
    seeded with 0. At each step the observations, rewards, ``terminated`` and
    ``truncated`` flags must match across implementations (infos are skipped).
    Envs that finish are reset. Afterwards the ``get_env_attr`` /
    ``set_env_attr`` API is checked. When the file is run as a script, a rough
    timing of each implementation is printed as well.
    """
    env_fns = [
        lambda i=i: MyTestEnv(size=i, sleep=sleep, recurse_state=True)
        for i in range(size, size + num)
    ]
    venv = [
        DummyVectorEnv(env_fns),
        SubprocVectorEnv(env_fns),
        ShmemVectorEnv(env_fns),
    ]
    if has_ray() and sys.platform == "linux":
        venv += [RayVectorEnv(env_fns)]
    try:
        for v in venv:
            v.seed(0)
        action_list = [1] * 5 + [0] * 10 + [1] * 20
        for v in venv:
            v.reset()
        for a in action_list:
            o = []
            for v in venv:
                obs, rew, terminated, truncated, info = v.step([a] * num)
                done = np.logical_or(terminated, truncated)
                if np.any(done):
                    # compare the reset observations of the finished envs
                    # instead of their final step observations
                    obs, _ = v.reset(np.where(done)[0])
                o.append([obs, rew, terminated, truncated, info])
            # compare each field of every implementation against the first one
            for index, results in enumerate(zip(*o)):
                if index == 4:  # do not check info here
                    continue
                for result in results:
                    assert recurse_comp(results[0], result)

        if __name__ == "__main__":
            t = [0] * len(venv)
            for i, e in enumerate(venv):
                t[i] = time.time()
                e.reset()
                for a in action_list:
                    # reset on truncation too, consistent with the check above
                    _, _, terminated, truncated, _ = e.step([a] * num)
                    done = np.logical_or(terminated, truncated)
                    if np.any(done):
                        e.reset(np.where(done)[0])
                t[i] = time.time() - t[i]
            for i, v in enumerate(venv):
                print(f"{type(v)}: {t[i]:.6f}s")

        def assert_get(v, expected):
            """Check ``size`` read from all envs, from env 0 and from envs 0-2."""
            assert v.get_env_attr("size") == expected
            assert v.get_env_attr("size", id=0) == [expected[0]]
            assert v.get_env_attr("size", id=[0, 1, 2]) == expected[:3]

        for v in venv:
            assert_get(v, list(range(size, size + num)))
            assert v.env_num == num
            assert v.action_space == [Discrete(2)] * num

            v.set_env_attr("size", 0)
            assert_get(v, [0] * num)

            v.set_env_attr("size", 1, 0)
            assert_get(v, [1] + [0] * (num - 1))

            v.set_env_attr("size", 2, [1, 2, 3])
            assert_get(v, [1] + [2] * 3 + [0] * (num - 4))
    finally:
        # release every vector env (and its workers) even if a check fails
        for v in venv:
            v.close()
def test_attr_unwrapped():
    """``set_env_attr`` must be visible on the worker env and its unwrapped env."""
    vec_env = DummyVectorEnv([lambda: gym.make("CartPole-v1")])
    vec_env.set_env_attr("test_attribute", 1337)
    assert vec_env.get_env_attr("test_attribute") == [1337]
    inner_env = vec_env.workers[0].env
    for target in (inner_env, inner_env.unwrapped):
        assert hasattr(target, "test_attribute")
def test_env_obs_dtype():
    """Stacked ``NXEnv`` observations must come back as an object array.

    Checked for both observation types (``"array"`` and ``"object"``), after
    ``reset`` and after ``step``.
    """
    for obs_type in ["array", "object"]:
        envs = SubprocVectorEnv(
            [lambda i=x, t=obs_type: NXEnv(i, t) for x in [5, 10, 15, 20]]
        )
        try:
            obs, _ = envs.reset()
            assert obs.dtype == object
            obs = envs.step([1, 1, 1, 1])[0]
            assert obs.dtype == object
        finally:
            # release the worker subprocesses before the next iteration
            envs.close()
def test_env_reset_optional_kwargs(size=10000, num=8):
    """``reset(seed=...)`` must work and return one info dict per env."""
    env_fns = [lambda i=i: MyTestEnv(size=i) for i in range(size, size + num)]
    test_cls = [DummyVectorEnv, SubprocVectorEnv, ShmemVectorEnv]
    if has_ray():
        test_cls += [RayVectorEnv]
    for cls in test_cls:
        v = cls(env_fns, wait_num=num // 2, timeout=1e-3)
        try:
            _, info = v.reset(seed=1)
            assert len(info) == len(env_fns)
            assert isinstance(info[0], dict)
        finally:
            # release worker processes before testing the next class
            v.close()
def test_venv_wrapper_gym(num_envs: int = 4):
    """``VectorEnvNormObs`` over gym envs must return ``(obs, info)`` on reset.

    Regression test for issue 697. If ``reset()`` (or unpacking its result)
    raises ``ValueError``, it is retried with ``return_info=True``.
    """
    vec_env = VectorEnvNormObs(
        DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(num_envs)])
    )
    try:
        obs, info = vec_env.reset()
    except ValueError:
        obs, info = vec_env.reset(return_info=True)
    assert isinstance(obs, np.ndarray)
    assert isinstance(info, list) and isinstance(info[0], dict)
    assert obs.shape[0] == len(info) == num_envs
def run_align_norm_obs(raw_env, train_env, test_env, action_list):
    """Check that observation-normalizing envs agree with a reference normalizer.

    ``raw_env`` and ``train_env`` are stepped in lockstep with the same
    actions. A reference ``RunningMeanStd`` is rebuilt from the raw
    observations, and every normalized train observation must match it. The
    final statistics reported by both ``train_env`` and ``test_env`` must equal
    the reference. ``test_env`` is then rolled out with the same actions, and
    its observations must match the raw ones normalized with the reference
    statistics.

    Both step APIs are handled: 5-tuple ``(obs, rew, terminated, truncated,
    info)`` and 4-tuple ``(obs, rew, done, info)``. Reset results may be either
    the bare observation or an ``(obs, info)`` tuple.

    :param raw_env: un-normalized vector env used as ground truth.
    :param train_env: normalizing env whose statistics are updated.
    :param test_env: normalizing env; assumes its statistics were taken from
        ``train_env`` and are not updated (TODO confirm at call sites).
    :param action_list: sequence of batched actions, one per step.
    """
    eps = np.finfo(np.float32).eps.item()

    def reset_result_to_obs(reset_result):
        """Extract the observation from a reset result.

        The result is possibly a tuple containing info.
        """
        if isinstance(reset_result, tuple) and len(reset_result) == 2:
            obs, _ = reset_result
            return obs
        return reset_result

    def step_and_record(env, action, obs_list):
        """Step ``env`` once and append its observations to ``obs_list``.

        Envs that finished (terminated or truncated) are reset by index, and
        the reset observations are appended as a separate entry.
        """
        step_result = env.step(action)
        if len(step_result) == 5:
            obs, _, terminated, truncated, _ = step_result
            done = np.logical_or(terminated, truncated)
        else:
            obs, _, done, _ = step_result
        obs_list.append(obs)
        if np.any(done):
            obs_list.append(reset_result_to_obs(env.reset(np.where(done)[0])))

    raw_reset_result = raw_env.reset()
    train_reset_result = train_env.reset()
    raw_obs = [reset_result_to_obs(raw_reset_result)]
    train_obs = [reset_result_to_obs(train_reset_result)]
    for action in action_list:
        # step raw and train alternately, exactly like a joint rollout
        step_and_record(raw_env, action, raw_obs)
        step_and_record(train_env, action, train_obs)
    # zip() below would silently drop unmatched entries
    assert len(raw_obs) == len(train_obs)

    ref_rms = RunningMeanStd()
    for ro, to in zip(raw_obs, train_obs):
        ref_rms.update(ro)
        no = (ro - ref_rms.mean) / np.sqrt(ref_rms.var + eps)
        assert np.allclose(no, to)
    for env in (train_env, test_env):
        assert np.allclose(ref_rms.mean, env.get_obs_rms().mean)
        assert np.allclose(ref_rms.var, env.get_obs_rms().var)

    test_obs = [reset_result_to_obs(test_env.reset())]
    for action in action_list:
        step_and_record(test_env, action, test_obs)
    assert len(raw_obs) == len(test_obs)
    for ro, to in zip(raw_obs, test_obs):
        no = (ro - ref_rms.mean) / np.sqrt(ref_rms.var + eps)
        assert np.allclose(no, to)
def test_venv_norm_obs():
    """``VectorEnvNormObs`` over ``MyTestEnv`` must match a reference normalizer."""
    sizes = np.array([5, 10, 15, 20])
    action_list = [np.array([1, 1, 1, 1])] * 30
    env_fns = [lambda i=x: MyTestEnv(size=i, array_state=True) for x in sizes]

    raw = DummyVectorEnv(env_fns)
    train_env = VectorEnvNormObs(DummyVectorEnv(env_fns))
    print(train_env.observation_space)
    # the test env receives the train env's statistics and does not update them
    test_env = VectorEnvNormObs(DummyVectorEnv(env_fns), update_obs_rms=False)
    test_env.set_obs_rms(train_env.get_obs_rms())
    run_align_norm_obs(raw, train_env, test_env, action_list)
def test_gym_wrappers():
    """Check ``ContinuousToDiscrete``, ``MultiDiscreteToDiscrete`` and
    ``TruncatedAsTerminated`` on a minimal dummy env."""

    class DummyEnv(gym.Env):
        """4-dim ``Box`` action space; ``step`` always reports truncation."""

        def __init__(self):
            self.action_space = gym.spaces.Box(
                low=-1.0, high=2.0, shape=(4, ), dtype=np.float32
            )
            self.observation_space = gym.spaces.Discrete(2)

        def step(self, act):
            # (obs, reward, terminated=False, truncated=True, info)
            return self.observation_space.sample(), -1, False, True, {}

    bsz = 10
    action_per_branch = [4, 6, 10, 7]
    env = DummyEnv()
    original_act = env.action_space.high
    # convert continuous to multidiscrete action space
    # with different action number per dimension
    env_m = ContinuousToDiscrete(env, action_per_branch)
    # the last index of every branch must map back to the upper bound
    # check conversion is working properly for one action
    np.testing.assert_allclose(env_m.action(env_m.action_space.nvec - 1), original_act)
    # check conversion is working properly for a batch of actions
    np.testing.assert_allclose(
        env_m.action(np.array([env_m.action_space.nvec - 1] * bsz)),
        np.array([original_act] * bsz),
    )
    # convert multidiscrete with different action number per
    # dimension to discrete action space
    env_d = MultiDiscreteToDiscrete(env_m)
    # the last discrete action must map to the last index of every branch
    # check conversion is working properly for one action
    np.testing.assert_allclose(
        env_d.action(env_d.action_space.n - 1), env_m.action_space.nvec - 1
    )
    # check conversion is working properly for a batch of actions
    np.testing.assert_allclose(
        env_d.action(np.array([env_d.action_space.n - 1] * bsz)),
        np.array([env_m.action_space.nvec - 1] * bsz),
    )
    # check terminated is True when the wrapped env truncates
    try:
        env_t = TruncatedAsTerminated(env)
    except EnvironmentError:
        # NOTE(review): presumably raised when the installed gym version is not
        # supported by the wrapper; the check is skipped in that case
        env_t = None
    if env_t is not None:
        # step returns (obs, rew, terminated, truncated, info); DummyEnv
        # truncates, so the wrapper must report terminated=True
        _, _, terminated, _, _ = env_t.step(env_t.action_space.sample())
        assert terminated
@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
def test_venv_wrapper_envpool():
    """``VectorEnvNormObs`` over envpool's Ant-v3 must match a reference normalizer."""
    num_envs = 4

    def make_ant():
        return envpool.make_gym("Ant-v3", num_envs=num_envs)

    raw = make_ant()
    train = VectorEnvNormObs(make_ant())
    test = VectorEnvNormObs(make_ant(), update_obs_rms=False)
    test.set_obs_rms(train.get_obs_rms())
    actions = [
        np.array([raw.action_space.sample() for _ in range(num_envs)])
        for _ in range(30)
    ]
    run_align_norm_obs(raw, train, test, actions)
@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
def test_venv_wrapper_envpool_gym_reset_return_info():
    """With ``gym_reset_return_info=True``, reset must return batched obs and info."""
    num_envs = 4
    env = VectorEnvNormObs(
        envpool.make_gym("Ant-v3", num_envs=num_envs, gym_reset_return_info=True)
    )
    obs, info = env.reset()
    assert obs.shape[0] == num_envs
    # every non-dict info entry must have one row per env
    for value in info.values():
        if not isinstance(value, dict):
            assert value.shape[0] == num_envs
if __name__ == "__main__":
    test_venv_norm_obs()
    test_venv_wrapper_gym()
    # pytest skips these via ``skipif``; mirror that when run as a script,
    # otherwise ``envpool.make_gym`` fails on a missing envpool
    if envpool is not None:
        test_venv_wrapper_envpool()
        test_venv_wrapper_envpool_gym_reset_return_info()
    test_env_obs_dtype()
    test_vecenv()
    test_attr_unwrapped()
    test_async_env()
    test_async_check_id()
    test_env_reset_optional_kwargs()
    test_gym_wrappers()