Tianshou/test/throughput/test_collector_profile.py

import numpy as np
import tqdm

from tianshou.data import AsyncCollector, Batch, Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv, SubprocVectorEnv
from tianshou.policy import BasePolicy
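# MyTestEnv is the minimal environment used by the test suite; under pytest it is
# imported from test/base/env.py, while a direct run imports the local env module.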
if __name__ == '__main__':
from env import MyTestEnv
else: # pytest
    from test.base.env import MyTestEnv


class MyPolicy(BasePolicy):
def __init__(self, dict_state=False, need_state=True):
"""
:param bool dict_state: if the observation of the environment is a dict
:param bool need_state: if the policy needs the hidden state (for RNN)
"""
super().__init__()
self.dict_state = dict_state
        self.need_state = need_state

    def forward(self, batch, state=None):
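        # Emit a constant action of 1 for every environment; when need_state is set,
        # keep a dummy hidden state that is zero-initialized and incremented per call.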
if self.need_state:
if state is None:
state = np.zeros((len(batch.obs), 2))
else:
state += 1
if self.dict_state:
return Batch(act=np.ones(len(batch.obs['index'])), state=state)
        return Batch(act=np.ones(len(batch.obs)), state=state)

    def learn(self):
        pass


def test_collector_nstep():
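    # Throughput of the synchronous Collector driven by n_step: each iteration
    # requests i * len(env_fns) transitions from 9 DummyVectorEnv workers.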
policy = MyPolicy()
env_fns = [lambda x=i: MyTestEnv(size=x) for i in np.arange(2, 11)]
dum = DummyVectorEnv(env_fns)
num = len(env_fns)
c3 = Collector(policy, dum, VectorReplayBuffer(total_size=40000, buffer_num=num))
for i in tqdm.trange(1, 400, desc="test step collector n_step"):
c3.reset()
result = c3.collect(n_step=i * len(env_fns))
        assert result['n/st'] >= i


def test_collector_nepisode():
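    # Throughput of the synchronous Collector driven by n_episode: the reported
    # episode count must equal i and the step count must match the buffer length.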
policy = MyPolicy()
env_fns = [lambda x=i: MyTestEnv(size=x) for i in np.arange(2, 11)]
dum = DummyVectorEnv(env_fns)
num = len(env_fns)
c3 = Collector(policy, dum, VectorReplayBuffer(total_size=40000, buffer_num=num))
    for i in tqdm.trange(1, 400, desc="test collector n_episode"):
c3.reset()
result = c3.collect(n_episode=i)
assert result['n/ep'] == i
        assert result['n/st'] == len(c3.buffer)


def test_asynccollector():
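    # Throughput of AsyncCollector on a SubprocVectorEnv with wait_num = 3 of 4
    # environments, so steps from the slowest worker may arrive out of order.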
env_lens = [2, 3, 4, 5]
env_fns = [
lambda x=i: MyTestEnv(size=x, sleep=0.001, random_sleep=True) for i in env_lens
]
venv = SubprocVectorEnv(env_fns, wait_num=len(env_fns) - 1)
policy = MyPolicy()
bufsize = 300
c1 = AsyncCollector(
policy, venv, VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4)
)
ptr = [0, 0, 0, 0]
for n_episode in tqdm.trange(1, 100, desc="test async n_episode"):
result = c1.collect(n_episode=n_episode)
assert result["n/ep"] >= n_episode
# check buffer data, obs and obs_next, env_id
for i, count in enumerate(np.bincount(result["lens"], minlength=6)[2:]):
env_len = i + 2
total = env_len * count
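            # every finished episode of env i occupies env_len consecutive slots in
            # sub-buffer i; advance a circular pointer (modulo bufsize) to locate them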
indices = np.arange(ptr[i], ptr[i] + total) % bufsize
ptr[i] = (ptr[i] + total) % bufsize
seq = np.arange(env_len)
buf = c1.buffer.buffers[i]
assert np.all(buf.info.env_id[indices] == i)
assert np.all(buf.obs[indices].reshape(count, env_len) == seq)
assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1)
    # test async n_step; by now the buffer should be full of data
for n_step in tqdm.trange(1, 150, desc="test async n_step"):
result = c1.collect(n_step=n_step)
assert result["n/st"] >= n_step
for i in range(4):
env_len = i + 2
seq = np.arange(env_len)
buf = c1.buffer.buffers[i]
assert np.all(buf.info.env_id == i)
assert np.all(buf.obs.reshape(-1, env_len) == seq)
            assert np.all(buf.obs_next.reshape(-1, env_len) == seq + 1)


if __name__ == '__main__':
test_collector_nstep()
test_collector_nepisode()
test_asynccollector()