# Tianshou/test/throughput/test_collector_profile.py
#
# 108 lines
# 3.8 KiB
# Python
# Raw Normal View History

import numpy as np
import tqdm
from tianshou.data import AsyncCollector, Batch, Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv, SubprocVectorEnv
from tianshou.policy import BasePolicy
# MyTestEnv lives in a different import path depending on how this file is
# loaded: as a standalone script it is imported from a top-level ``env``
# module, otherwise (pytest) from the ``test.base.env`` package path.
if __name__ == '__main__':
    from env import MyTestEnv
else:  # pytest
    from test.base.env import MyTestEnv
class MyPolicy(BasePolicy):
    """Minimal deterministic policy used to drive the collectors in these tests.

    Every call to :meth:`forward` returns an action of ``1`` for each entry in
    the batch; no learning takes place.
    """

    def __init__(self, dict_state=False, need_state=True):
        """
        :param bool dict_state: if the observation of the environment is a dict
        :param bool need_state: if the policy needs the hidden state (for RNN)
        """
        super().__init__()
        self.dict_state = dict_state
        self.need_state = need_state

    def forward(self, batch, state=None, **kwargs):
        """Return an all-ones action array and the (optionally updated) state.

        :param batch: input batch; ``batch.obs`` is read (``batch.obs['index']``
            when ``dict_state`` is set) only to determine how many actions to
            emit.
        :param state: hidden state from the previous call, or ``None``.
        :param kwargs: accepted for compatibility with the ``BasePolicy.forward``
            signature and ignored.
        :return: a ``Batch`` with keys ``act`` and ``state``.
        """
        if self.need_state:
            if state is None:
                # First call: start from a zero state with one row per obs.
                state = np.zeros((len(batch.obs), 2))
            else:
                # NOTE: increments the caller's array in place.
                state += 1
        if self.dict_state:
            return Batch(act=np.ones(len(batch.obs['index'])), state=state)
        return Batch(act=np.ones(len(batch.obs)), state=state)

    def learn(self, batch=None, **kwargs):
        """Do nothing; this policy has no trainable parts.

        ``batch`` and ``kwargs`` are accepted (and ignored) so that the method
        can be invoked through the ``BasePolicy.learn(batch, **kwargs)``
        interface without raising ``TypeError``. Calling ``learn()`` with no
        arguments keeps working.
        """
        pass
def test_collector_nstep():
    """Exercise step-based collection over a ``DummyVectorEnv``.

    Nine ``MyTestEnv`` instances (``size`` 2 through 10) are wrapped in a
    ``DummyVectorEnv``. For ``i`` in ``1..399`` the collector is reset and asked
    for ``i * num`` steps, where ``num`` is the number of environments; the
    reported step count must reach that requested amount.
    """
    policy = MyPolicy()
    # ``x=i`` binds the current size as a default to avoid late-binding closures.
    env_fns = [lambda x=i: MyTestEnv(size=x) for i in np.arange(2, 11)]
    dum = DummyVectorEnv(env_fns)
    num = len(env_fns)
    c3 = Collector(policy, dum, VectorReplayBuffer(total_size=40000, buffer_num=num))
    for i in tqdm.trange(1, 400, desc="test step collector n_step"):
        c3.reset()
        n_step = i * num
        result = c3.collect(n_step=n_step)
        # Compare against the number of steps actually requested, not just
        # ``i``: the weaker bound would let a collector that under-collects by
        # up to a factor of ``num`` go unnoticed.
        assert result['n/st'] >= n_step
def test_collector_nepisode():
    """Exercise episode-based collection over a ``DummyVectorEnv``.

    For each target in ``1..399`` the collector is reset and asked for exactly
    that many episodes. The reported episode count must equal the target, and
    the reported step count must equal the current length of the collector's
    buffer.
    """
    policy = MyPolicy()
    # The default argument freezes each size at lambda-creation time.
    env_fns = [lambda x=size: MyTestEnv(size=x) for size in np.arange(2, 11)]
    vec_env = DummyVectorEnv(env_fns)
    replay = VectorReplayBuffer(total_size=40000, buffer_num=len(env_fns))
    collector = Collector(policy, vec_env, replay)

    progress = tqdm.trange(1, 400, desc="test step collector n_episode")
    for target in progress:
        collector.reset()
        stats = collector.collect(n_episode=target)
        assert stats['n/ep'] == target
        assert stats['n/st'] == len(collector.buffer)
def test_asynccollector():
    """Exercise ``AsyncCollector`` on a ``SubprocVectorEnv`` and verify buffer contents.

    Four ``MyTestEnv`` instances (``size`` 2, 3, 4, 5; random sleeps enabled)
    run in subprocesses, with the vector env waiting for all but one of them
    per step. Two phases follow:

    1. Episode-based collection for ``n_episode`` in ``1..99``. After each call
       the data newly written to every per-env sub-buffer is checked: the
       ``env_id`` matches the sub-buffer index, and ``obs`` / ``obs_next`` form
       the sequences ``0..len-1`` / ``1..len`` for each finished episode.
    2. Step-based collection for ``n_step`` in ``1..149``. By now every
       sub-buffer is expected to be full, so whole sub-buffers are checked.

    The vector env is always closed on exit so that its worker subprocesses do
    not outlive the test, even when an assertion fails.
    """
    env_lens = [2, 3, 4, 5]
    num_envs = len(env_lens)
    env_fns = [
        lambda x=i: MyTestEnv(size=x, sleep=0.001, random_sleep=True) for i in env_lens
    ]
    venv = SubprocVectorEnv(env_fns, wait_num=num_envs - 1)
    try:
        policy = MyPolicy()
        # Per-env sub-buffer size; 300 is divisible by each of 2, 3, 4 and 5, so
        # the ``reshape(-1, env_len)`` calls in the second phase are valid.
        bufsize = 300
        c1 = AsyncCollector(
            policy,
            venv,
            VectorReplayBuffer(total_size=bufsize * num_envs, buffer_num=num_envs),
        )
        # Write position inside each sub-buffer, tracked independently here so
        # the freshly written slice can be located after every collect call.
        ptr = [0] * num_envs
        for n_episode in tqdm.trange(1, 100, desc="test async n_episode"):
            result = c1.collect(n_episode=n_episode)
            assert result["n/ep"] >= n_episode
            # check buffer data, obs and obs_next, env_id
            # bincount()[2:] yields, for episode lengths 2..5, how many episodes
            # of that length finished; slot i is compared with sub-buffer i.
            for i, count in enumerate(np.bincount(result["lens"], minlength=6)[2:]):
                env_len = i + 2
                total = env_len * count
                # Sub-buffers wrap around, hence the modulo on the index range.
                indices = np.arange(ptr[i], ptr[i] + total) % bufsize
                ptr[i] = (ptr[i] + total) % bufsize
                seq = np.arange(env_len)
                buf = c1.buffer.buffers[i]
                assert np.all(buf.info.env_id[indices] == i)
                assert np.all(buf.obs[indices].reshape(count, env_len) == seq)
                assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1)
        # test async n_step, for now the buffer should be full of data
        for n_step in tqdm.trange(1, 150, desc="test async n_step"):
            result = c1.collect(n_step=n_step)
            assert result["n/st"] >= n_step
            for i in range(num_envs):
                env_len = i + 2
                seq = np.arange(env_len)
                buf = c1.buffer.buffers[i]
                assert np.all(buf.info.env_id == i)
                assert np.all(buf.obs.reshape(-1, env_len) == seq)
                assert np.all(buf.obs_next.reshape(-1, env_len) == seq + 1)
    finally:
        # Shut down the worker subprocesses regardless of the test outcome.
        venv.close()
if __name__ == '__main__':
    # Script entry point: run each profile once, in definition order.
    for _profile_case in (
        test_collector_nstep,
        test_collector_nepisode,
        test_asynccollector,
    ):
        _profile_case()