2020-08-02 18:24:40 +08:00
|
|
|
import numpy as np
|
2021-09-03 05:05:04 +08:00
|
|
|
import tqdm
|
2020-08-02 18:24:40 +08:00
|
|
|
|
2021-09-03 05:05:04 +08:00
|
|
|
from tianshou.data import AsyncCollector, Batch, Collector, VectorReplayBuffer
|
2021-02-19 10:33:49 +08:00
|
|
|
from tianshou.env import DummyVectorEnv, SubprocVectorEnv
|
2021-09-03 05:05:04 +08:00
|
|
|
from tianshou.policy import BasePolicy
|
2020-08-02 18:24:40 +08:00
|
|
|
|
2021-02-19 10:33:49 +08:00
|
|
|
if __name__ == '__main__':
|
|
|
|
from env import MyTestEnv
|
|
|
|
else: # pytest
|
|
|
|
from test.base.env import MyTestEnv
|
|
|
|
|
|
|
|
|
|
|
|
class MyPolicy(BasePolicy):
|
2021-09-03 05:05:04 +08:00
|
|
|
|
2021-02-19 10:33:49 +08:00
|
|
|
def __init__(self, dict_state=False, need_state=True):
|
|
|
|
"""
|
|
|
|
:param bool dict_state: if the observation of the environment is a dict
|
|
|
|
:param bool need_state: if the policy needs the hidden state (for RNN)
|
|
|
|
"""
|
|
|
|
super().__init__()
|
|
|
|
self.dict_state = dict_state
|
|
|
|
self.need_state = need_state
|
|
|
|
|
|
|
|
def forward(self, batch, state=None):
|
|
|
|
if self.need_state:
|
|
|
|
if state is None:
|
|
|
|
state = np.zeros((len(batch.obs), 2))
|
|
|
|
else:
|
|
|
|
state += 1
|
|
|
|
if self.dict_state:
|
|
|
|
return Batch(act=np.ones(len(batch.obs['index'])), state=state)
|
|
|
|
return Batch(act=np.ones(len(batch.obs)), state=state)
|
|
|
|
|
|
|
|
def learn(self):
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
def test_collector_nstep():
|
|
|
|
policy = MyPolicy()
|
|
|
|
env_fns = [lambda x=i: MyTestEnv(size=x) for i in np.arange(2, 11)]
|
|
|
|
dum = DummyVectorEnv(env_fns)
|
|
|
|
num = len(env_fns)
|
2021-09-03 05:05:04 +08:00
|
|
|
c3 = Collector(policy, dum, VectorReplayBuffer(total_size=40000, buffer_num=num))
|
2021-02-19 10:33:49 +08:00
|
|
|
for i in tqdm.trange(1, 400, desc="test step collector n_step"):
|
|
|
|
c3.reset()
|
|
|
|
result = c3.collect(n_step=i * len(env_fns))
|
|
|
|
assert result['n/st'] >= i
|
|
|
|
|
|
|
|
|
|
|
|
def test_collector_nepisode():
|
|
|
|
policy = MyPolicy()
|
|
|
|
env_fns = [lambda x=i: MyTestEnv(size=x) for i in np.arange(2, 11)]
|
|
|
|
dum = DummyVectorEnv(env_fns)
|
|
|
|
num = len(env_fns)
|
2021-09-03 05:05:04 +08:00
|
|
|
c3 = Collector(policy, dum, VectorReplayBuffer(total_size=40000, buffer_num=num))
|
2021-02-19 10:33:49 +08:00
|
|
|
for i in tqdm.trange(1, 400, desc="test step collector n_episode"):
|
|
|
|
c3.reset()
|
|
|
|
result = c3.collect(n_episode=i)
|
|
|
|
assert result['n/ep'] == i
|
|
|
|
assert result['n/st'] == len(c3.buffer)
|
|
|
|
|
|
|
|
|
|
|
|
def test_asynccollector():
|
|
|
|
env_lens = [2, 3, 4, 5]
|
2021-09-03 05:05:04 +08:00
|
|
|
env_fns = [
|
|
|
|
lambda x=i: MyTestEnv(size=x, sleep=0.001, random_sleep=True) for i in env_lens
|
|
|
|
]
|
2021-02-19 10:33:49 +08:00
|
|
|
|
|
|
|
venv = SubprocVectorEnv(env_fns, wait_num=len(env_fns) - 1)
|
|
|
|
policy = MyPolicy()
|
|
|
|
bufsize = 300
|
|
|
|
c1 = AsyncCollector(
|
2021-09-03 05:05:04 +08:00
|
|
|
policy, venv, VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4)
|
|
|
|
)
|
2021-02-19 10:33:49 +08:00
|
|
|
ptr = [0, 0, 0, 0]
|
|
|
|
for n_episode in tqdm.trange(1, 100, desc="test async n_episode"):
|
|
|
|
result = c1.collect(n_episode=n_episode)
|
|
|
|
assert result["n/ep"] >= n_episode
|
|
|
|
# check buffer data, obs and obs_next, env_id
|
2021-09-03 05:05:04 +08:00
|
|
|
for i, count in enumerate(np.bincount(result["lens"], minlength=6)[2:]):
|
2021-02-19 10:33:49 +08:00
|
|
|
env_len = i + 2
|
|
|
|
total = env_len * count
|
|
|
|
indices = np.arange(ptr[i], ptr[i] + total) % bufsize
|
|
|
|
ptr[i] = (ptr[i] + total) % bufsize
|
|
|
|
seq = np.arange(env_len)
|
|
|
|
buf = c1.buffer.buffers[i]
|
|
|
|
assert np.all(buf.info.env_id[indices] == i)
|
|
|
|
assert np.all(buf.obs[indices].reshape(count, env_len) == seq)
|
2021-09-03 05:05:04 +08:00
|
|
|
assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1)
|
2021-02-19 10:33:49 +08:00
|
|
|
# test async n_step, for now the buffer should be full of data
|
|
|
|
for n_step in tqdm.trange(1, 150, desc="test async n_step"):
|
|
|
|
result = c1.collect(n_step=n_step)
|
|
|
|
assert result["n/st"] >= n_step
|
|
|
|
for i in range(4):
|
|
|
|
env_len = i + 2
|
|
|
|
seq = np.arange(env_len)
|
|
|
|
buf = c1.buffer.buffers[i]
|
|
|
|
assert np.all(buf.info.env_id == i)
|
|
|
|
assert np.all(buf.obs.reshape(-1, env_len) == seq)
|
|
|
|
assert np.all(buf.obs_next.reshape(-1, env_len) == seq + 1)
|
2020-08-02 18:24:40 +08:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
2021-02-19 10:33:49 +08:00
|
|
|
test_collector_nstep()
|
|
|
|
test_collector_nepisode()
|
|
|
|
test_asynccollector()
|