Tianshou/test/throughput/test_collector_profile.py
ChenDRAG 150d0ec51b
Step collector implementation (#280)
This is the third PR of 6 commits mentioned in #274, which features refactor of Collector to fix #245. You can check #274 for more detail.

Things changed in this PR:

1. refactor collector to be more cleaner, split AsyncCollector to support asyncvenv;
2. change buffer.add api to add(batch, bffer_ids); add several types of buffer (VectorReplayBuffer, PrioritizedVectorReplayBuffer, etc.)
3. add policy.exploration_noise(act, batch) -> act
4. small change in BasePolicy.compute_*_returns
5. move reward_metric from collector to trainer
6. fix np.asanyarray issue (different version's numpy will result in different output)
7. flake8 maxlength=88
8. polish docs and fix test

Co-authored-by: n+e <trinkle23897@gmail.com>
2021-02-19 10:33:49 +08:00

110 lines
3.8 KiB
Python

import tqdm
import numpy as np
from tianshou.policy import BasePolicy
from tianshou.env import DummyVectorEnv, SubprocVectorEnv
from tianshou.data import Batch, Collector, AsyncCollector, VectorReplayBuffer
if __name__ == '__main__':
from env import MyTestEnv
else: # pytest
from test.base.env import MyTestEnv
class MyPolicy(BasePolicy):
def __init__(self, dict_state=False, need_state=True):
"""
:param bool dict_state: if the observation of the environment is a dict
:param bool need_state: if the policy needs the hidden state (for RNN)
"""
super().__init__()
self.dict_state = dict_state
self.need_state = need_state
def forward(self, batch, state=None):
if self.need_state:
if state is None:
state = np.zeros((len(batch.obs), 2))
else:
state += 1
if self.dict_state:
return Batch(act=np.ones(len(batch.obs['index'])), state=state)
return Batch(act=np.ones(len(batch.obs)), state=state)
def learn(self):
pass
def test_collector_nstep():
policy = MyPolicy()
env_fns = [lambda x=i: MyTestEnv(size=x) for i in np.arange(2, 11)]
dum = DummyVectorEnv(env_fns)
num = len(env_fns)
c3 = Collector(policy, dum,
VectorReplayBuffer(total_size=40000, buffer_num=num))
for i in tqdm.trange(1, 400, desc="test step collector n_step"):
c3.reset()
result = c3.collect(n_step=i * len(env_fns))
assert result['n/st'] >= i
def test_collector_nepisode():
policy = MyPolicy()
env_fns = [lambda x=i: MyTestEnv(size=x) for i in np.arange(2, 11)]
dum = DummyVectorEnv(env_fns)
num = len(env_fns)
c3 = Collector(policy, dum,
VectorReplayBuffer(total_size=40000, buffer_num=num))
for i in tqdm.trange(1, 400, desc="test step collector n_episode"):
c3.reset()
result = c3.collect(n_episode=i)
assert result['n/ep'] == i
assert result['n/st'] == len(c3.buffer)
def test_asynccollector():
env_lens = [2, 3, 4, 5]
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0.001, random_sleep=True)
for i in env_lens]
venv = SubprocVectorEnv(env_fns, wait_num=len(env_fns) - 1)
policy = MyPolicy()
bufsize = 300
c1 = AsyncCollector(
policy, venv,
VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4))
ptr = [0, 0, 0, 0]
for n_episode in tqdm.trange(1, 100, desc="test async n_episode"):
result = c1.collect(n_episode=n_episode)
assert result["n/ep"] >= n_episode
# check buffer data, obs and obs_next, env_id
for i, count in enumerate(
np.bincount(result["lens"], minlength=6)[2:]):
env_len = i + 2
total = env_len * count
indices = np.arange(ptr[i], ptr[i] + total) % bufsize
ptr[i] = (ptr[i] + total) % bufsize
seq = np.arange(env_len)
buf = c1.buffer.buffers[i]
assert np.all(buf.info.env_id[indices] == i)
assert np.all(buf.obs[indices].reshape(count, env_len) == seq)
assert np.all(buf.obs_next[indices].reshape(
count, env_len) == seq + 1)
# test async n_step, for now the buffer should be full of data
for n_step in tqdm.trange(1, 150, desc="test async n_step"):
result = c1.collect(n_step=n_step)
assert result["n/st"] >= n_step
for i in range(4):
env_len = i + 2
seq = np.arange(env_len)
buf = c1.buffer.buffers[i]
assert np.all(buf.info.env_id == i)
assert np.all(buf.obs.reshape(-1, env_len) == seq)
assert np.all(buf.obs_next.reshape(-1, env_len) == seq + 1)
if __name__ == '__main__':
test_collector_nstep()
test_collector_nepisode()
test_asynccollector()