Tianshou/test/base/test_collector.py

427 lines
16 KiB
Python
Raw Normal View History

import tqdm
import pytest
2020-03-25 14:08:28 +08:00
import numpy as np
2020-04-10 18:02:05 +08:00
from torch.utils.tensorboard import SummaryWriter
2020-03-25 14:08:28 +08:00
from tianshou.policy import BasePolicy
from tianshou.env import DummyVectorEnv, SubprocVectorEnv
from tianshou.data import Batch, Collector, AsyncCollector
from tianshou.data import (
ReplayBuffer,
PrioritizedReplayBuffer,
VectorReplayBuffer,
CachedReplayBuffer,
)
2020-03-25 14:08:28 +08:00
if __name__ == '__main__':
from env import MyTestEnv, NXEnv
2020-03-25 14:08:28 +08:00
else: # pytest
from test.base.env import MyTestEnv, NXEnv
2020-03-25 14:08:28 +08:00
class MyPolicy(BasePolicy):
def __init__(self, dict_state=False, need_state=True):
"""
:param bool dict_state: if the observation of the environment is a dict
:param bool need_state: if the policy needs the hidden state (for RNN)
"""
2020-03-25 14:08:28 +08:00
super().__init__()
2020-04-28 20:56:02 +08:00
self.dict_state = dict_state
self.need_state = need_state
2020-03-25 14:08:28 +08:00
2020-04-10 10:47:16 +08:00
def forward(self, batch, state=None):
if self.need_state:
if state is None:
state = np.zeros((len(batch.obs), 2))
else:
state += 1
2020-04-28 20:56:02 +08:00
if self.dict_state:
return Batch(act=np.ones(len(batch.obs['index'])), state=state)
return Batch(act=np.ones(len(batch.obs)), state=state)
2020-03-25 14:08:28 +08:00
def learn(self):
pass
class Logger:
2020-04-10 18:02:05 +08:00
def __init__(self, writer):
self.cnt = 0
self.writer = writer
def preprocess_fn(self, **kwargs):
# modify info before adding into the buffer, and recorded into tfb
# if only obs exist -> reset
# if obs_next/rew/done/info exist -> normal step
if 'rew' in kwargs:
info = kwargs['info']
info.rew = kwargs['rew']
if 'key' in info.keys():
self.writer.add_scalar(
'key', np.mean(info.key), global_step=self.cnt)
self.cnt += 1
return Batch(info=info)
else:
return Batch()
@staticmethod
def single_preprocess_fn(**kwargs):
# same as above, without tfb
if 'rew' in kwargs:
info = kwargs['info']
info.rew = kwargs['rew']
return Batch(info=info)
else:
return Batch()
2020-04-10 18:02:05 +08:00
2020-03-25 14:08:28 +08:00
def test_collector():
2020-04-10 18:02:05 +08:00
writer = SummaryWriter('log/collector')
logger = Logger(writer)
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0) for i in [2, 3, 4, 5]]
2020-04-10 18:02:05 +08:00
2020-03-25 14:08:28 +08:00
venv = SubprocVectorEnv(env_fns)
dum = DummyVectorEnv(env_fns)
2020-03-25 14:08:28 +08:00
policy = MyPolicy()
env = env_fns[0]()
c0 = Collector(policy, env, ReplayBuffer(size=100), logger.preprocess_fn)
c0.collect(n_step=3)
assert len(c0.buffer) == 3
assert np.allclose(c0.buffer.obs[:4, 0], [0, 1, 0, 0])
assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1])
c0.collect(n_episode=3)
assert len(c0.buffer) == 8
assert np.allclose(c0.buffer.obs[:10, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 0])
assert np.allclose(c0.buffer[:].obs_next[..., 0],
[1, 2, 1, 2, 1, 2, 1, 2])
c0.collect(n_step=3, random=True)
c1 = Collector(
policy, venv,
VectorReplayBuffer(total_size=100, buffer_num=4),
logger.preprocess_fn)
c1.collect(n_step=8)
obs = np.zeros(100)
obs[[0, 1, 25, 26, 50, 51, 75, 76]] = [0, 1, 0, 1, 0, 1, 0, 1]
2020-03-25 14:08:28 +08:00
assert np.allclose(c1.buffer.obs[:, 0], obs)
assert np.allclose(c1.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2])
c1.collect(n_episode=4)
assert len(c1.buffer) == 16
obs[[2, 3, 27, 52, 53, 77, 78, 79]] = [0, 1, 2, 2, 3, 2, 3, 4]
assert np.allclose(c1.buffer.obs[:, 0], obs)
assert np.allclose(c1.buffer[:].obs_next[..., 0],
[1, 2, 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5])
c1.collect(n_episode=4, random=True)
c2 = Collector(
policy, dum,
VectorReplayBuffer(total_size=100, buffer_num=4),
logger.preprocess_fn)
c2.collect(n_episode=7)
obs1 = obs.copy()
obs1[[4, 5, 28, 29, 30]] = [0, 1, 0, 1, 2]
obs2 = obs.copy()
obs2[[28, 29, 30, 54, 55, 56, 57]] = [0, 1, 2, 0, 1, 2, 3]
c2obs = c2.buffer.obs[:, 0]
assert np.all(c2obs == obs1) or np.all(c2obs == obs2)
c2.reset_env()
c2.reset_buffer()
assert c2.collect(n_episode=8)['n/ep'] == 8
obs[[4, 5, 28, 29, 30, 54, 55, 56, 57]] = [0, 1, 0, 1, 2, 0, 1, 2, 3]
assert np.all(c2.buffer.obs[:, 0] == obs)
c2.collect(n_episode=4, random=True)
# test corner case
with pytest.raises(TypeError):
Collector(policy, dum, ReplayBuffer(10))
with pytest.raises(TypeError):
Collector(policy, dum, PrioritizedReplayBuffer(10, 0.5, 0.5))
with pytest.raises(TypeError):
c2.collect()
# test NXEnv
for obs_type in ["array", "object"]:
envs = SubprocVectorEnv([
lambda i=x: NXEnv(i, obs_type) for x in [5, 10, 15, 20]])
c3 = Collector(policy, envs,
VectorReplayBuffer(total_size=100, buffer_num=4))
c3.collect(n_step=6)
assert c3.buffer.obs.dtype == object
def test_collector_with_async():
env_lens = [2, 3, 4, 5]
writer = SummaryWriter('log/async_collector')
logger = Logger(writer)
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0.001, random_sleep=True)
for i in env_lens]
venv = SubprocVectorEnv(env_fns, wait_num=len(env_fns) - 1)
policy = MyPolicy()
bufsize = 60
c1 = AsyncCollector(
policy, venv,
VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4),
logger.preprocess_fn)
ptr = [0, 0, 0, 0]
for n_episode in tqdm.trange(1, 30, desc="test async n_episode"):
result = c1.collect(n_episode=n_episode)
assert result["n/ep"] >= n_episode
# check buffer data, obs and obs_next, env_id
for i, count in enumerate(
np.bincount(result["lens"], minlength=6)[2:]):
env_len = i + 2
total = env_len * count
indices = np.arange(ptr[i], ptr[i] + total) % bufsize
ptr[i] = (ptr[i] + total) % bufsize
seq = np.arange(env_len)
buf = c1.buffer.buffers[i]
assert np.all(buf.info.env_id[indices] == i)
assert np.all(buf.obs[indices].reshape(count, env_len) == seq)
assert np.all(buf.obs_next[indices].reshape(
count, env_len) == seq + 1)
# test async n_step, for now the buffer should be full of data
for n_step in tqdm.trange(1, 15, desc="test async n_step"):
result = c1.collect(n_step=n_step)
assert result["n/st"] >= n_step
for i in range(4):
env_len = i + 2
seq = np.arange(env_len)
buf = c1.buffer.buffers[i]
assert np.all(buf.info.env_id == i)
assert np.all(buf.obs.reshape(-1, env_len) == seq)
assert np.all(buf.obs_next.reshape(-1, env_len) == seq + 1)
with pytest.raises(TypeError):
c1.collect()
2020-04-28 20:56:02 +08:00
def test_collector_with_dict_state():
env = MyTestEnv(size=5, sleep=0, dict_state=True)
policy = MyPolicy(dict_state=True)
c0 = Collector(policy, env, ReplayBuffer(size=100),
Logger.single_preprocess_fn)
2020-04-28 20:56:02 +08:00
c0.collect(n_step=3)
c0.collect(n_episode=2)
assert len(c0.buffer) == 10
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, dict_state=True)
for i in [2, 3, 4, 5]]
envs = DummyVectorEnv(env_fns)
envs.seed(666)
obs = envs.reset()
assert not np.isclose(obs[0]['rand'], obs[1]['rand'])
c1 = Collector(
policy, envs,
VectorReplayBuffer(total_size=100, buffer_num=4),
Logger.single_preprocess_fn)
c1.collect(n_step=12)
result = c1.collect(n_episode=8)
assert result['n/ep'] == 8
lens = np.bincount(result['lens'])
assert result['n/st'] == 21 and np.all(lens == [0, 0, 2, 2, 2, 2]) or \
result['n/st'] == 20 and np.all(lens == [0, 0, 3, 1, 2, 2])
batch, _ = c1.buffer.sample(10)
c0.buffer.update(c1.buffer)
assert len(c0.buffer) in [42, 43]
if len(c0.buffer) == 42:
assert np.all(c0.buffer[:].obs.index[..., 0] == [
0, 1, 2, 3, 4, 0, 1, 2, 3, 4,
0, 1, 0, 1, 0, 1, 0, 1,
0, 1, 2, 0, 1, 2,
0, 1, 2, 3, 0, 1, 2, 3,
0, 1, 2, 3, 4, 0, 1, 2, 3, 4,
]), c0.buffer[:].obs.index[..., 0]
else:
assert np.all(c0.buffer[:].obs.index[..., 0] == [
0, 1, 2, 3, 4, 0, 1, 2, 3, 4,
0, 1, 0, 1, 0, 1,
0, 1, 2, 0, 1, 2, 0, 1, 2,
0, 1, 2, 3, 0, 1, 2, 3,
0, 1, 2, 3, 4, 0, 1, 2, 3, 4,
]), c0.buffer[:].obs.index[..., 0]
c2 = Collector(
policy, envs,
VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4),
Logger.single_preprocess_fn)
c2.collect(n_episode=10)
batch, _ = c2.buffer.sample(10)
2020-04-28 20:56:02 +08:00
def test_collector_with_ma():
env = MyTestEnv(size=5, sleep=0, ma_rew=4)
policy = MyPolicy()
c0 = Collector(policy, env, ReplayBuffer(size=100),
Logger.single_preprocess_fn)
# n_step=3 will collect a full episode
r = c0.collect(n_step=3)['rews']
assert len(r) == 0
r = c0.collect(n_episode=2)['rews']
assert r.shape == (2, 4) and np.all(r == 1)
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, ma_rew=4)
for i in [2, 3, 4, 5]]
envs = DummyVectorEnv(env_fns)
c1 = Collector(
policy, envs,
VectorReplayBuffer(total_size=100, buffer_num=4),
Logger.single_preprocess_fn)
r = c1.collect(n_step=12)['rews']
assert r.shape == (2, 4) and np.all(r == 1), r
r = c1.collect(n_episode=8)['rews']
assert r.shape == (8, 4) and np.all(r == 1)
batch, _ = c1.buffer.sample(10)
print(batch)
c0.buffer.update(c1.buffer)
assert len(c0.buffer) in [42, 43]
if len(c0.buffer) == 42:
rew = [
0, 0, 0, 0, 1, 0, 0, 0, 0, 1,
0, 1, 0, 1, 0, 1, 0, 1,
0, 0, 1, 0, 0, 1,
0, 0, 0, 1, 0, 0, 0, 1,
0, 0, 0, 0, 1, 0, 0, 0, 0, 1,
]
else:
rew = [
0, 0, 0, 0, 1, 0, 0, 0, 0, 1,
0, 1, 0, 1, 0, 1,
0, 0, 1, 0, 0, 1, 0, 0, 1,
0, 0, 0, 1, 0, 0, 0, 1,
0, 0, 0, 0, 1, 0, 0, 0, 0, 1,
]
assert np.all(c0.buffer[:].rew == [[x] * 4 for x in rew])
assert np.all(c0.buffer[:].done == rew)
c2 = Collector(
policy, envs,
VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4),
Logger.single_preprocess_fn)
r = c2.collect(n_episode=10)['rews']
assert r.shape == (10, 4) and np.all(r == 1)
batch, _ = c2.buffer.sample(10)
def test_collector_with_atari_setting():
reference_obs = np.zeros([6, 4, 84, 84])
for i in range(6):
reference_obs[i, 3, np.arange(84), np.arange(84)] = i
reference_obs[i, 2, np.arange(84)] = i
reference_obs[i, 1, :, np.arange(84)] = i
reference_obs[i, 0] = i
# atari single buffer
env = MyTestEnv(size=5, sleep=0, array_state=True)
policy = MyPolicy()
c0 = Collector(policy, env, ReplayBuffer(size=100))
c0.collect(n_step=6)
c0.collect(n_episode=2)
assert c0.buffer.obs.shape == (100, 4, 84, 84)
assert c0.buffer.obs_next.shape == (100, 4, 84, 84)
assert len(c0.buffer) == 15
obs = np.zeros_like(c0.buffer.obs)
obs[np.arange(15)] = reference_obs[np.arange(15) % 5]
assert np.all(obs == c0.buffer.obs)
c1 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=True))
c1.collect(n_episode=3)
assert np.allclose(c0.buffer.obs, c1.buffer.obs)
with pytest.raises(AttributeError):
c1.buffer.obs_next
assert np.all(reference_obs[[1, 2, 3, 4, 4] * 3] == c1.buffer[:].obs_next)
c2 = Collector(
policy, env,
ReplayBuffer(size=100, ignore_obs_next=True, save_only_last_obs=True))
c2.collect(n_step=8)
assert c2.buffer.obs.shape == (100, 84, 84)
obs = np.zeros_like(c2.buffer.obs)
obs[np.arange(8)] = reference_obs[[0, 1, 2, 3, 4, 0, 1, 2], -1]
assert np.all(c2.buffer.obs == obs)
assert np.allclose(c2.buffer[:].obs_next,
reference_obs[[1, 2, 3, 4, 4, 1, 2, 2], -1])
# atari multi buffer
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, array_state=True)
for i in [2, 3, 4, 5]]
envs = DummyVectorEnv(env_fns)
c3 = Collector(
policy, envs,
VectorReplayBuffer(total_size=100, buffer_num=4))
c3.collect(n_step=12)
result = c3.collect(n_episode=9)
assert result["n/ep"] == 9 and result["n/st"] == 23
assert c3.buffer.obs.shape == (100, 4, 84, 84)
obs = np.zeros_like(c3.buffer.obs)
obs[np.arange(8)] = reference_obs[[0, 1, 0, 1, 0, 1, 0, 1]]
obs[np.arange(25, 34)] = reference_obs[[0, 1, 2, 0, 1, 2, 0, 1, 2]]
obs[np.arange(50, 58)] = reference_obs[[0, 1, 2, 3, 0, 1, 2, 3]]
obs[np.arange(75, 85)] = reference_obs[[0, 1, 2, 3, 4, 0, 1, 2, 3, 4]]
assert np.all(obs == c3.buffer.obs)
obs_next = np.zeros_like(c3.buffer.obs_next)
obs_next[np.arange(8)] = reference_obs[[1, 2, 1, 2, 1, 2, 1, 2]]
obs_next[np.arange(25, 34)] = reference_obs[[1, 2, 3, 1, 2, 3, 1, 2, 3]]
obs_next[np.arange(50, 58)] = reference_obs[[1, 2, 3, 4, 1, 2, 3, 4]]
obs_next[np.arange(75, 85)] = reference_obs[[1, 2, 3, 4, 5, 1, 2, 3, 4, 5]]
assert np.all(obs_next == c3.buffer.obs_next)
c4 = Collector(
policy, envs,
VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4,
ignore_obs_next=True, save_only_last_obs=True))
c4.collect(n_step=12)
result = c4.collect(n_episode=9)
assert result["n/ep"] == 9 and result["n/st"] == 23
assert c4.buffer.obs.shape == (100, 84, 84)
obs = np.zeros_like(c4.buffer.obs)
slice_obs = reference_obs[:, -1]
obs[np.arange(8)] = slice_obs[[0, 1, 0, 1, 0, 1, 0, 1]]
obs[np.arange(25, 34)] = slice_obs[[0, 1, 2, 0, 1, 2, 0, 1, 2]]
obs[np.arange(50, 58)] = slice_obs[[0, 1, 2, 3, 0, 1, 2, 3]]
obs[np.arange(75, 85)] = slice_obs[[0, 1, 2, 3, 4, 0, 1, 2, 3, 4]]
assert np.all(c4.buffer.obs == obs)
obs_next = np.zeros([len(c4.buffer), 4, 84, 84])
ref_index = np.array([
1, 1, 1, 1, 1, 1, 1, 1,
1, 2, 2, 1, 2, 2, 1, 2, 2,
1, 2, 3, 3, 1, 2, 3, 3,
1, 2, 3, 4, 4, 1, 2, 3, 4, 4,
])
obs_next[:, -1] = slice_obs[ref_index]
ref_index -= 1
ref_index[ref_index < 0] = 0
obs_next[:, -2] = slice_obs[ref_index]
ref_index -= 1
ref_index[ref_index < 0] = 0
obs_next[:, -3] = slice_obs[ref_index]
ref_index -= 1
ref_index[ref_index < 0] = 0
obs_next[:, -4] = slice_obs[ref_index]
assert np.all(obs_next == c4.buffer[:].obs_next)
buf = ReplayBuffer(100, stack_num=4, ignore_obs_next=True,
save_only_last_obs=True)
c5 = Collector(policy, envs, CachedReplayBuffer(buf, 4, 10))
result_ = c5.collect(n_step=12)
assert len(buf) == 5 and len(c5.buffer) == 12
result = c5.collect(n_episode=9)
assert result["n/ep"] == 9 and result["n/st"] == 23
assert len(buf) == 35
assert np.all(buf.obs[:len(buf)] == slice_obs[[
0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 1, 2, 3, 4,
0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4]])
assert np.all(buf[:].obs_next[:, -1] == slice_obs[[
1, 1, 1, 2, 2, 1, 1, 1, 2, 3, 3, 1, 2, 3, 4, 4,
1, 1, 1, 2, 2, 1, 1, 1, 2, 3, 3, 1, 2, 2, 1, 2, 3, 4, 4]])
assert len(buf) == len(c5.buffer)
# test buffer=None
c6 = Collector(policy, envs)
result1 = c6.collect(n_step=12)
for key in ["n/ep", "n/st", "rews", "lens"]:
assert np.allclose(result1[key], result_[key])
result2 = c6.collect(n_episode=9)
for key in ["n/ep", "n/st", "rews", "lens"]:
assert np.allclose(result2[key], result[key])
2020-03-25 14:08:28 +08:00
if __name__ == '__main__':
test_collector()
2020-04-28 20:56:02 +08:00
test_collector_with_dict_state()
test_collector_with_ma()
test_collector_with_atari_setting()
test_collector_with_async()