2020-03-25 14:08:28 +08:00
|
|
|
import numpy as np
|
|
|
|
from tianshou.policy import BasePolicy
|
|
|
|
from tianshou.env import SubprocVectorEnv
|
2020-04-10 09:01:17 +08:00
|
|
|
from tianshou.data import Collector, Batch, ReplayBuffer
|
2020-03-25 14:08:28 +08:00
|
|
|
|
|
|
|
# Pick the import path for MyTestEnv depending on how this module is loaded:
# the package-qualified path when imported (e.g. collected by pytest), the
# top-level ``env`` module when executed directly as a script.
if __name__ != '__main__':  # pytest
    from test.base.env import MyTestEnv
else:
    from env import MyTestEnv
|
|
|
|
|
|
|
|
|
|
|
|
class MyPolicy(BasePolicy):
    """Stub policy for the collector test: always answers with action 1."""

    def __init__(self):
        super().__init__()

    def forward(self, batch, state=None):
        """Return a ``Batch`` whose ``act`` field is a vector of ones.

        One action is produced per entry along the first axis of
        ``batch.obs``. ``state`` is accepted but not used.
        """
        n_obs = batch.obs.shape[0]
        actions = np.ones(n_obs)
        return Batch(act=actions)

    def learn(self):
        """Intentionally a no-op."""
        return None
|
|
|
|
|
|
|
|
|
|
|
|
def equal(a, b):
    """Return True when ``a`` and ``b`` match in shape and (nearly) in value.

    Both arguments are converted to NumPy arrays. They are considered equal
    when their shapes are identical and the sum of absolute element-wise
    differences is below ``1e-6``.

    The shape check is deliberate: without it NumPy broadcasting would let
    inputs of different lengths compare as equal (for example a one-element
    or an empty array against a longer one), which could hide a wrong-sized
    slice in a test assertion.
    """
    lhs, rhs = np.asarray(a), np.asarray(b)
    if lhs.shape != rhs.shape:
        # Never rely on broadcasting to compare differently shaped data.
        return False
    return bool(np.abs(lhs - rhs).sum() < 1e-6)
|
|
|
|
|
|
|
|
|
|
|
|
def test_collector():
    """Check what ``Collector`` writes into its replay buffer.

    Three collectors are exercised with ``MyPolicy``:

    * ``c0`` on a single ``MyTestEnv(size=2)``, by step count and then by
      episode count;
    * ``c1`` on a ``SubprocVectorEnv`` of four envs (sizes 2 to 5), by step
      count and then by episode count;
    * ``c2`` on the same vectorized env, with a per-environment episode
      count, before and after ``reset_env()``.

    After each ``collect`` call the stored ``obs`` / ``obs_next`` values are
    compared against the expected sequences with ``equal``.

    The vectorized env is closed in a ``finally`` clause so that it is
    released even when an assertion fails part-way through the test.
    """
    env_fns = [
        lambda: MyTestEnv(size=2, sleep=0),
        lambda: MyTestEnv(size=3, sleep=0),
        lambda: MyTestEnv(size=4, sleep=0),
        lambda: MyTestEnv(size=5, sleep=0),
    ]
    venv = SubprocVectorEnv(env_fns)
    try:
        policy = MyPolicy()
        env = env_fns[0]()

        # Single environment.
        c0 = Collector(
            policy, env, ReplayBuffer(size=100, ignore_obs_next=False))
        c0.collect(n_step=3)
        assert equal(c0.buffer.obs[:3], [0, 1, 0])
        assert equal(c0.buffer[:3].obs_next, [1, 2, 1])
        c0.collect(n_episode=3)
        assert equal(c0.buffer.obs[:8], [0, 1, 0, 1, 0, 1, 0, 1])
        assert equal(c0.buffer[:8].obs_next, [1, 2, 1, 2, 1, 2, 1, 2])

        # Vectorized environment, scalar step / episode budgets.
        c1 = Collector(
            policy, venv, ReplayBuffer(size=100, ignore_obs_next=False))
        c1.collect(n_step=6)
        assert equal(c1.buffer.obs[:11], [0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3])
        assert equal(c1.buffer[:11].obs_next,
                     [1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4])
        c1.collect(n_episode=2)
        assert equal(c1.buffer.obs[11:21], [0, 1, 2, 3, 4, 0, 1, 0, 1, 2])
        assert equal(c1.buffer[11:21].obs_next,
                     [1, 2, 3, 4, 5, 1, 2, 1, 2, 3])

        # Vectorized environment, one episode budget per sub-environment.
        c2 = Collector(
            policy, venv, ReplayBuffer(size=100, ignore_obs_next=False))
        c2.collect(n_episode=[1, 2, 2, 2])
        assert equal(c2.buffer.obs_next[:26], [
            1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5,
            1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5])
        c2.reset_env()
        c2.collect(n_episode=[2, 2, 2, 2])
        assert equal(c2.buffer.obs_next[26:54], [
            1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4, 1, 2, 3, 4, 5,
            1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5])
    finally:
        # Shut the vectorized env down regardless of the test outcome.
        venv.close()
|
|
|
|
|
|
|
|
|
|
|
|
# Allow running the test directly as a script, without pytest.
if __name__ == '__main__':
    test_collector()
|