2021-02-19 10:33:49 +08:00
|
|
|
import tqdm
|
|
|
|
import pytest
|
2020-03-25 14:08:28 +08:00
|
|
|
import numpy as np
|
2020-04-10 18:02:05 +08:00
|
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
|
|
|
2020-03-25 14:08:28 +08:00
|
|
|
from tianshou.policy import BasePolicy
|
2020-08-19 15:00:24 +08:00
|
|
|
from tianshou.env import DummyVectorEnv, SubprocVectorEnv
|
2021-02-19 10:33:49 +08:00
|
|
|
from tianshou.data import Batch, Collector, AsyncCollector
|
|
|
|
from tianshou.data import (
|
|
|
|
ReplayBuffer,
|
|
|
|
PrioritizedReplayBuffer,
|
|
|
|
VectorReplayBuffer,
|
|
|
|
CachedReplayBuffer,
|
|
|
|
)
|
2020-03-25 14:08:28 +08:00
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Running this file directly as a script: import the helper envs from the
    # sibling module in the same directory.
    from env import MyTestEnv, NXEnv
else:  # pytest collects this file as test.base.test_collector: use package path
    from test.base.env import MyTestEnv, NXEnv
|
2020-03-25 14:08:28 +08:00
|
|
|
|
|
|
|
|
|
|
|
class MyPolicy(BasePolicy):
    """A minimal deterministic policy for exercising the Collector.

    Every call emits action ``1`` for each environment.  A dummy hidden
    state can be carried between calls so RNN-style state handling in the
    collector is covered as well.

    :param bool dict_state: if the observation of the environment is a dict
    :param bool need_state: if the policy needs the hidden state (for RNN)
    """

    def __init__(self, dict_state=False, need_state=True):
        super().__init__()
        self.dict_state = dict_state
        self.need_state = need_state

    def forward(self, batch, state=None):
        # Fake recurrent state: zeros on the first call, then incremented
        # in place on every subsequent call.
        if self.need_state:
            if state is None:
                state = np.zeros((len(batch.obs), 2))
            else:
                state += 1
        # With a dict observation the batch size comes from the 'index' key.
        obs = batch.obs['index'] if self.dict_state else batch.obs
        return Batch(act=np.ones(len(obs)), state=state)

    def learn(self):
        # Tests never train this policy.
        pass
|
|
|
|
|
|
|
|
|
2020-07-23 16:40:53 +08:00
|
|
|
class Logger:
    """Copy ``rew`` into ``info`` during collection and mirror it to tensorboard."""

    def __init__(self, writer):
        self.cnt = 0
        self.writer = writer

    def preprocess_fn(self, **kwargs):
        # modify info before adding into the buffer, and recorded into tfb
        # if only obs exist -> reset
        # if obs_next/rew/done/info exist -> normal step
        if 'rew' not in kwargs:
            # reset call: nothing to record
            return Batch()
        info = kwargs['info']
        info.rew = kwargs['rew']
        if 'key' in info.keys():
            self.writer.add_scalar(
                'key', np.mean(info.key), global_step=self.cnt)
        self.cnt += 1
        return Batch(info=info)

    @staticmethod
    def single_preprocess_fn(**kwargs):
        # same as above, without tfb
        if 'rew' not in kwargs:
            return Batch()
        info = kwargs['info']
        info.rew = kwargs['rew']
        return Batch(info=info)
|
2020-04-10 18:02:05 +08:00
|
|
|
|
|
|
|
|
2020-03-25 14:08:28 +08:00
|
|
|
def test_collector():
    """Basic Collector behavior on a single env and on vectorized envs.

    Checks exact buffer contents after n_step / n_episode collection, the
    random-action path, constructor type checks, and object-dtype obs.
    """
    writer = SummaryWriter('log/collector')
    logger = Logger(writer)
    # four envs with episode lengths 2, 3, 4, 5 (x=i binds the loop variable)
    env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0) for i in [2, 3, 4, 5]]

    venv = SubprocVectorEnv(env_fns)
    dum = DummyVectorEnv(env_fns)
    policy = MyPolicy()
    env = env_fns[0]()

    # --- single env, plain ReplayBuffer ---
    c0 = Collector(policy, env, ReplayBuffer(size=100), logger.preprocess_fn)
    c0.collect(n_step=3)
    assert len(c0.buffer) == 3
    assert np.allclose(c0.buffer.obs[:4, 0], [0, 1, 0, 0])
    assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1])
    c0.collect(n_episode=3)
    # 3 more episodes of length 2 minus overlap -> 8 transitions stored
    assert len(c0.buffer) == 8
    assert np.allclose(c0.buffer.obs[:10, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 0])
    assert np.allclose(c0.buffer[:].obs_next[..., 0],
                       [1, 2, 1, 2, 1, 2, 1, 2])
    c0.collect(n_step=3, random=True)

    # --- subprocess vector env, VectorReplayBuffer (4 sub-buffers of 25) ---
    c1 = Collector(
        policy, venv,
        VectorReplayBuffer(total_size=100, buffer_num=4),
        logger.preprocess_fn)
    c1.collect(n_step=8)
    # each sub-buffer starts at offsets 0/25/50/75; two steps went into each
    obs = np.zeros(100)
    obs[[0, 1, 25, 26, 50, 51, 75, 76]] = [0, 1, 0, 1, 0, 1, 0, 1]

    assert np.allclose(c1.buffer.obs[:, 0], obs)
    assert np.allclose(c1.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2])
    c1.collect(n_episode=4)
    assert len(c1.buffer) == 16
    obs[[2, 3, 27, 52, 53, 77, 78, 79]] = [0, 1, 2, 2, 3, 2, 3, 4]
    assert np.allclose(c1.buffer.obs[:, 0], obs)
    assert np.allclose(c1.buffer[:].obs_next[..., 0],
                       [1, 2, 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5])
    c1.collect(n_episode=4, random=True)

    # --- dummy vector env; n_episode=7 may finish envs in two valid orders ---
    c2 = Collector(
        policy, dum,
        VectorReplayBuffer(total_size=100, buffer_num=4),
        logger.preprocess_fn)
    c2.collect(n_episode=7)
    obs1 = obs.copy()
    obs1[[4, 5, 28, 29, 30]] = [0, 1, 0, 1, 2]
    obs2 = obs.copy()
    obs2[[28, 29, 30, 54, 55, 56, 57]] = [0, 1, 2, 0, 1, 2, 3]
    c2obs = c2.buffer.obs[:, 0]
    # either completion order is acceptable
    assert np.all(c2obs == obs1) or np.all(c2obs == obs2)
    c2.reset_env()
    c2.reset_buffer()
    assert c2.collect(n_episode=8)['n/ep'] == 8
    obs[[4, 5, 28, 29, 30, 54, 55, 56, 57]] = [0, 1, 0, 1, 2, 0, 1, 2, 3]
    assert np.all(c2.buffer.obs[:, 0] == obs)
    c2.collect(n_episode=4, random=True)

    # test corner case
    # a vectorized env must not be paired with a single-env buffer
    with pytest.raises(TypeError):
        Collector(policy, dum, ReplayBuffer(10))
    with pytest.raises(TypeError):
        Collector(policy, dum, PrioritizedReplayBuffer(10, 0.5, 0.5))
    # collect() requires exactly one of n_step / n_episode
    with pytest.raises(TypeError):
        c2.collect()

    # test NXEnv
    # networkx-graph observations are stored with dtype=object
    for obs_type in ["array", "object"]:
        envs = SubprocVectorEnv([
            lambda i=x: NXEnv(i, obs_type) for x in [5, 10, 15, 20]])
        c3 = Collector(policy, envs,
                       VectorReplayBuffer(total_size=100, buffer_num=4))
        c3.collect(n_step=6)
        assert c3.buffer.obs.dtype == object
|
|
|
|
|
2020-08-27 12:15:18 +08:00
|
|
|
|
2020-07-26 12:01:21 +02:00
|
|
|
def test_collector_with_async():
    """AsyncCollector with randomly-sleeping envs and partial wait_num.

    After each collect call, verifies per-env ring-buffer contents (obs,
    obs_next, env_id) using manually tracked write pointers.
    """
    env_lens = [2, 3, 4, 5]
    writer = SummaryWriter('log/async_collector')
    logger = Logger(writer)
    # random_sleep makes env step-completion order nondeterministic,
    # which is exactly what the async collector must tolerate
    env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0.001, random_sleep=True)
               for i in env_lens]

    # wait_num < len(env_fns): collector proceeds before all envs finish
    venv = SubprocVectorEnv(env_fns, wait_num=len(env_fns) - 1)
    policy = MyPolicy()
    bufsize = 60
    c1 = AsyncCollector(
        policy, venv,
        VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4),
        logger.preprocess_fn)
    # per-env write pointer into each ring sub-buffer
    ptr = [0, 0, 0, 0]
    for n_episode in tqdm.trange(1, 30, desc="test async n_episode"):
        result = c1.collect(n_episode=n_episode)
        # async collection may overshoot the requested episode count
        assert result["n/ep"] >= n_episode
        # check buffer data, obs and obs_next, env_id
        # bincount over episode lengths -> how many episodes each env finished
        for i, count in enumerate(
                np.bincount(result["lens"], minlength=6)[2:]):
            env_len = i + 2
            total = env_len * count
            # indices written this round, modulo the ring-buffer size
            indices = np.arange(ptr[i], ptr[i] + total) % bufsize
            ptr[i] = (ptr[i] + total) % bufsize
            seq = np.arange(env_len)
            buf = c1.buffer.buffers[i]
            assert np.all(buf.info.env_id[indices] == i)
            # each episode of env i is the sequence 0..env_len-1
            assert np.all(buf.obs[indices].reshape(count, env_len) == seq)
            assert np.all(buf.obs_next[indices].reshape(
                count, env_len) == seq + 1)
    # test async n_step, for now the buffer should be full of data
    for n_step in tqdm.trange(1, 15, desc="test async n_step"):
        result = c1.collect(n_step=n_step)
        assert result["n/st"] >= n_step
        for i in range(4):
            env_len = i + 2
            seq = np.arange(env_len)
            buf = c1.buffer.buffers[i]
            # buffer is full, so the whole sub-buffer must hold env i's data
            assert np.all(buf.info.env_id == i)
            assert np.all(buf.obs.reshape(-1, env_len) == seq)
            assert np.all(buf.obs_next.reshape(-1, env_len) == seq + 1)
    # collect() without n_step/n_episode is invalid
    with pytest.raises(TypeError):
        c1.collect()
|
2020-07-26 12:01:21 +02:00
|
|
|
|
|
|
|
|
2020-04-28 20:56:02 +08:00
|
|
|
def test_collector_with_dict_state():
    """Collector with dict observations, incl. buffer merge and frame stacking."""
    env = MyTestEnv(size=5, sleep=0, dict_state=True)
    policy = MyPolicy(dict_state=True)
    c0 = Collector(policy, env, ReplayBuffer(size=100),
                   Logger.single_preprocess_fn)
    c0.collect(n_step=3)
    c0.collect(n_episode=2)
    # 2 full episodes of length 5
    assert len(c0.buffer) == 10
    env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, dict_state=True)
               for i in [2, 3, 4, 5]]
    envs = DummyVectorEnv(env_fns)
    envs.seed(666)
    obs = envs.reset()
    # per-env seeding must produce different random obs entries
    assert not np.isclose(obs[0]['rand'], obs[1]['rand'])
    c1 = Collector(
        policy, envs,
        VectorReplayBuffer(total_size=100, buffer_num=4),
        Logger.single_preprocess_fn)
    c1.collect(n_step=12)
    result = c1.collect(n_episode=8)
    assert result['n/ep'] == 8
    lens = np.bincount(result['lens'])
    # two valid episode-length distributions depending on finishing order
    assert result['n/st'] == 21 and np.all(lens == [0, 0, 2, 2, 2, 2]) or \
        result['n/st'] == 20 and np.all(lens == [0, 0, 3, 1, 2, 2])
    batch, _ = c1.buffer.sample(10)
    # merging a VectorReplayBuffer into a plain ReplayBuffer
    c0.buffer.update(c1.buffer)
    assert len(c0.buffer) in [42, 43]
    if len(c0.buffer) == 42:
        assert np.all(c0.buffer[:].obs.index[..., 0] == [
            0, 1, 2, 3, 4, 0, 1, 2, 3, 4,
            0, 1, 0, 1, 0, 1, 0, 1,
            0, 1, 2, 0, 1, 2,
            0, 1, 2, 3, 0, 1, 2, 3,
            0, 1, 2, 3, 4, 0, 1, 2, 3, 4,
        ]), c0.buffer[:].obs.index[..., 0]
    else:
        assert np.all(c0.buffer[:].obs.index[..., 0] == [
            0, 1, 2, 3, 4, 0, 1, 2, 3, 4,
            0, 1, 0, 1, 0, 1,
            0, 1, 2, 0, 1, 2, 0, 1, 2,
            0, 1, 2, 3, 0, 1, 2, 3,
            0, 1, 2, 3, 4, 0, 1, 2, 3, 4,
        ]), c0.buffer[:].obs.index[..., 0]
    # frame-stacked buffer must still work with dict observations
    c2 = Collector(
        policy, envs,
        VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4),
        Logger.single_preprocess_fn)
    c2.collect(n_episode=10)
    batch, _ = c2.buffer.sample(10)
|
2020-04-28 20:56:02 +08:00
|
|
|
|
|
|
|
|
2020-07-13 00:24:31 +08:00
|
|
|
def test_collector_with_ma():
    """Collector with multi-agent (vector) rewards of width 4."""
    env = MyTestEnv(size=5, sleep=0, ma_rew=4)
    policy = MyPolicy()
    c0 = Collector(policy, env, ReplayBuffer(size=100),
                   Logger.single_preprocess_fn)
    # n_step=3 will collect a full episode
    r = c0.collect(n_step=3)['rews']
    # no episode finished within 3 steps -> empty reward array
    assert len(r) == 0
    r = c0.collect(n_episode=2)['rews']
    # one reward row per episode, one column per agent
    assert r.shape == (2, 4) and np.all(r == 1)
    env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, ma_rew=4)
               for i in [2, 3, 4, 5]]
    envs = DummyVectorEnv(env_fns)
    c1 = Collector(
        policy, envs,
        VectorReplayBuffer(total_size=100, buffer_num=4),
        Logger.single_preprocess_fn)
    r = c1.collect(n_step=12)['rews']
    assert r.shape == (2, 4) and np.all(r == 1), r
    r = c1.collect(n_episode=8)['rews']
    assert r.shape == (8, 4) and np.all(r == 1)
    batch, _ = c1.buffer.sample(10)
    print(batch)
    c0.buffer.update(c1.buffer)
    assert len(c0.buffer) in [42, 43]
    # reward==1 exactly at episode ends; layout depends on finishing order
    if len(c0.buffer) == 42:
        rew = [
            0, 0, 0, 0, 1, 0, 0, 0, 0, 1,
            0, 1, 0, 1, 0, 1, 0, 1,
            0, 0, 1, 0, 0, 1,
            0, 0, 0, 1, 0, 0, 0, 1,
            0, 0, 0, 0, 1, 0, 0, 0, 0, 1,
        ]
    else:
        rew = [
            0, 0, 0, 0, 1, 0, 0, 0, 0, 1,
            0, 1, 0, 1, 0, 1,
            0, 0, 1, 0, 0, 1, 0, 0, 1,
            0, 0, 0, 1, 0, 0, 0, 1,
            0, 0, 0, 0, 1, 0, 0, 0, 0, 1,
        ]
    # every agent sees the same scalar reward replicated 4 times
    assert np.all(c0.buffer[:].rew == [[x] * 4 for x in rew])
    assert np.all(c0.buffer[:].done == rew)
    # frame stacking combined with multi-agent rewards
    c2 = Collector(
        policy, envs,
        VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4),
        Logger.single_preprocess_fn)
    r = c2.collect(n_episode=10)['rews']
    assert r.shape == (10, 4) and np.all(r == 1)
    batch, _ = c2.buffer.sample(10)
|
2020-07-13 00:24:31 +08:00
|
|
|
|
|
|
|
|
2021-02-19 10:33:49 +08:00
|
|
|
def test_collector_with_atari_setting():
    """Atari-style (4, 84, 84) frame-stacked observations through the Collector.

    Exercises ignore_obs_next, save_only_last_obs, stack_num, CachedReplayBuffer,
    and collection without any buffer, checking exact stored frames throughout.
    """
    # reference_obs[i] is the unique, distinguishable 4x84x84 frame the test
    # env emits at step i (each channel marked differently with value i)
    reference_obs = np.zeros([6, 4, 84, 84])
    for i in range(6):
        reference_obs[i, 3, np.arange(84), np.arange(84)] = i
        reference_obs[i, 2, np.arange(84)] = i
        reference_obs[i, 1, :, np.arange(84)] = i
        reference_obs[i, 0] = i

    # atari single buffer
    env = MyTestEnv(size=5, sleep=0, array_state=True)
    policy = MyPolicy()
    c0 = Collector(policy, env, ReplayBuffer(size=100))
    c0.collect(n_step=6)
    c0.collect(n_episode=2)
    assert c0.buffer.obs.shape == (100, 4, 84, 84)
    assert c0.buffer.obs_next.shape == (100, 4, 84, 84)
    assert len(c0.buffer) == 15
    obs = np.zeros_like(c0.buffer.obs)
    # 15 stored steps cycle through the 5-step episode
    obs[np.arange(15)] = reference_obs[np.arange(15) % 5]
    assert np.all(obs == c0.buffer.obs)

    # ignore_obs_next: obs_next is not stored but reconstructed on access
    c1 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=True))
    c1.collect(n_episode=3)
    assert np.allclose(c0.buffer.obs, c1.buffer.obs)
    with pytest.raises(AttributeError):
        c1.buffer.obs_next
    # terminal steps repeat the last obs (index 4 appears twice per episode)
    assert np.all(reference_obs[[1, 2, 3, 4, 4] * 3] == c1.buffer[:].obs_next)

    # save_only_last_obs: only the newest frame of the stack is stored
    c2 = Collector(
        policy, env,
        ReplayBuffer(size=100, ignore_obs_next=True, save_only_last_obs=True))
    c2.collect(n_step=8)
    assert c2.buffer.obs.shape == (100, 84, 84)
    obs = np.zeros_like(c2.buffer.obs)
    obs[np.arange(8)] = reference_obs[[0, 1, 2, 3, 4, 0, 1, 2], -1]
    assert np.all(c2.buffer.obs == obs)
    assert np.allclose(c2.buffer[:].obs_next,
                       reference_obs[[1, 2, 3, 4, 4, 1, 2, 2], -1])

    # atari multi buffer
    env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, array_state=True)
               for i in [2, 3, 4, 5]]
    envs = DummyVectorEnv(env_fns)
    c3 = Collector(
        policy, envs,
        VectorReplayBuffer(total_size=100, buffer_num=4))
    c3.collect(n_step=12)
    result = c3.collect(n_episode=9)
    assert result["n/ep"] == 9 and result["n/st"] == 23
    assert c3.buffer.obs.shape == (100, 4, 84, 84)
    # sub-buffers start at offsets 0/25/50/75; check each env's frames
    obs = np.zeros_like(c3.buffer.obs)
    obs[np.arange(8)] = reference_obs[[0, 1, 0, 1, 0, 1, 0, 1]]
    obs[np.arange(25, 34)] = reference_obs[[0, 1, 2, 0, 1, 2, 0, 1, 2]]
    obs[np.arange(50, 58)] = reference_obs[[0, 1, 2, 3, 0, 1, 2, 3]]
    obs[np.arange(75, 85)] = reference_obs[[0, 1, 2, 3, 4, 0, 1, 2, 3, 4]]
    assert np.all(obs == c3.buffer.obs)
    obs_next = np.zeros_like(c3.buffer.obs_next)
    obs_next[np.arange(8)] = reference_obs[[1, 2, 1, 2, 1, 2, 1, 2]]
    obs_next[np.arange(25, 34)] = reference_obs[[1, 2, 3, 1, 2, 3, 1, 2, 3]]
    obs_next[np.arange(50, 58)] = reference_obs[[1, 2, 3, 4, 1, 2, 3, 4]]
    obs_next[np.arange(75, 85)] = reference_obs[[1, 2, 3, 4, 5, 1, 2, 3, 4, 5]]
    assert np.all(obs_next == c3.buffer.obs_next)

    # stack_num + ignore_obs_next + save_only_last_obs together
    c4 = Collector(
        policy, envs,
        VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4,
                           ignore_obs_next=True, save_only_last_obs=True))
    c4.collect(n_step=12)
    result = c4.collect(n_episode=9)
    assert result["n/ep"] == 9 and result["n/st"] == 23
    assert c4.buffer.obs.shape == (100, 84, 84)
    obs = np.zeros_like(c4.buffer.obs)
    # only the last frame of each 4-frame stack is kept
    slice_obs = reference_obs[:, -1]
    obs[np.arange(8)] = slice_obs[[0, 1, 0, 1, 0, 1, 0, 1]]
    obs[np.arange(25, 34)] = slice_obs[[0, 1, 2, 0, 1, 2, 0, 1, 2]]
    obs[np.arange(50, 58)] = slice_obs[[0, 1, 2, 3, 0, 1, 2, 3]]
    obs[np.arange(75, 85)] = slice_obs[[0, 1, 2, 3, 4, 0, 1, 2, 3, 4]]
    assert np.all(c4.buffer.obs == obs)
    # rebuild the expected stacked obs_next: channel -k holds the frame
    # k-1 steps earlier, clamped at episode start
    obs_next = np.zeros([len(c4.buffer), 4, 84, 84])
    ref_index = np.array([
        1, 1, 1, 1, 1, 1, 1, 1,
        1, 2, 2, 1, 2, 2, 1, 2, 2,
        1, 2, 3, 3, 1, 2, 3, 3,
        1, 2, 3, 4, 4, 1, 2, 3, 4, 4,
    ])
    obs_next[:, -1] = slice_obs[ref_index]
    ref_index -= 1
    ref_index[ref_index < 0] = 0
    obs_next[:, -2] = slice_obs[ref_index]
    ref_index -= 1
    ref_index[ref_index < 0] = 0
    obs_next[:, -3] = slice_obs[ref_index]
    ref_index -= 1
    ref_index[ref_index < 0] = 0
    obs_next[:, -4] = slice_obs[ref_index]
    assert np.all(obs_next == c4.buffer[:].obs_next)

    # CachedReplayBuffer: finished episodes are flushed into the main buf
    buf = ReplayBuffer(100, stack_num=4, ignore_obs_next=True,
                       save_only_last_obs=True)
    c5 = Collector(policy, envs, CachedReplayBuffer(buf, 4, 10))
    result_ = c5.collect(n_step=12)
    # only completed episodes (5 steps) reached the main buffer so far
    assert len(buf) == 5 and len(c5.buffer) == 12
    result = c5.collect(n_episode=9)
    assert result["n/ep"] == 9 and result["n/st"] == 23
    assert len(buf) == 35
    assert np.all(buf.obs[:len(buf)] == slice_obs[[
        0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 1, 2, 3, 4,
        0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4]])
    assert np.all(buf[:].obs_next[:, -1] == slice_obs[[
        1, 1, 1, 2, 2, 1, 1, 1, 2, 3, 3, 1, 2, 3, 4, 4,
        1, 1, 1, 2, 2, 1, 1, 1, 2, 3, 3, 1, 2, 2, 1, 2, 3, 4, 4]])
    # all cached episodes have been flushed by now
    assert len(buf) == len(c5.buffer)

    # test buffer=None
    # collection statistics must match those from the buffered run above
    c6 = Collector(policy, envs)
    result1 = c6.collect(n_step=12)
    for key in ["n/ep", "n/st", "rews", "lens"]:
        assert np.allclose(result1[key], result_[key])
    result2 = c6.collect(n_episode=9)
    for key in ["n/ep", "n/st", "rews", "lens"]:
        assert np.allclose(result2[key], result[key])
|
|
|
|
|
|
|
|
|
2020-03-25 14:08:28 +08:00
|
|
|
if __name__ == '__main__':
    # Run the whole collector test suite directly, without pytest.
    test_collector()
    test_collector_with_dict_state()
    test_collector_with_ma()
    test_collector_with_atari_setting()
    test_collector_with_async()
|