2020-03-25 14:08:28 +08:00
|
|
|
import numpy as np
|
2020-04-10 18:02:05 +08:00
|
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
|
|
|
2020-03-25 14:08:28 +08:00
|
|
|
from tianshou.policy import BasePolicy
|
2020-04-28 20:56:02 +08:00
|
|
|
from tianshou.env import VectorEnv, SubprocVectorEnv
|
2020-04-10 09:01:17 +08:00
|
|
|
from tianshou.data import Collector, Batch, ReplayBuffer
|
2020-03-25 14:08:28 +08:00
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
from env import MyTestEnv
|
|
|
|
else: # pytest
|
|
|
|
from test.base.env import MyTestEnv
|
|
|
|
|
|
|
|
|
|
|
|
class MyPolicy(BasePolicy):
    """Dummy policy for the collector tests: every action it emits is 1."""

    def __init__(self, dict_state=False):
        super().__init__()
        # when True, ``forward`` sizes its output from ``batch.obs['index']``
        self.dict_state = dict_state

    def forward(self, batch, state=None):
        """Return a ``Batch`` whose ``act`` is an array of ones.

        The array has one entry per observation in ``batch``; ``state`` is
        accepted but not used.
        """
        obs = batch.obs['index'] if self.dict_state else batch.obs
        return Batch(act=np.ones(len(obs)))

    def learn(self):
        """Do nothing; this policy has no parameters to update."""
|
|
|
|
|
|
|
|
|
2020-05-05 13:39:51 +08:00
|
|
|
def preprocess_fn(**kwargs):
    """Copy each step's reward into the matching ``info`` entry.

    Returns ``{'info': info}`` with ``rew`` written into every entry when a
    non-empty ``info`` is passed in, and an empty ``Batch()`` otherwise.
    """
    # if info is not provided from env, it will be a ``Batch()``.
    if kwargs.get('info', Batch()).is_empty():
        return Batch()

    # modify info before adding into the buffer
    count = len(kwargs['obs'])
    info = kwargs['info']
    for idx in range(count):
        info[idx].update(rew=kwargs['rew'][idx])
    # or: return Batch(info=info)
    return {'info': info}
|
2020-05-05 13:39:51 +08:00
|
|
|
|
|
|
|
|
2020-04-10 18:02:05 +08:00
|
|
|
class Logger:
    """Forward the mean of ``info['key']`` to a writer on every ``log`` call."""

    def __init__(self, writer):
        self.writer = writer
        # number of completed log() calls; passed to the writer as global_step
        self.cnt = 0

    def log(self, info):
        """Record ``np.mean(info['key'])`` under the tag ``'key'``.

        The current call count is used as ``global_step`` and is incremented
        only after the writer call returns.
        """
        mean_value = np.mean(info['key'])
        self.writer.add_scalar('key', mean_value, global_step=self.cnt)
        self.cnt += 1
|
|
|
|
|
|
|
|
|
2020-03-25 14:08:28 +08:00
|
|
|
def test_collector():
    """Check what ``Collector`` stores in its buffer for three env flavours.

    A single env, a ``SubprocVectorEnv`` and a ``VectorEnv`` are each driven
    through step-based, episode-based and random collection, and the recorded
    ``obs`` / ``obs_next`` sequences are compared with hard-coded values.
    """
    def make_buffer():
        # every collector in this test keeps obs_next explicitly
        return ReplayBuffer(size=100, ignore_obs_next=False)

    writer = SummaryWriter('log/collector')
    logger = Logger(writer)
    # bind the size as a default argument so each factory keeps its own value
    env_fns = [
        lambda size=n: MyTestEnv(size=size, sleep=0) for n in (2, 3, 4, 5)
    ]

    venv = SubprocVectorEnv(env_fns)
    dum = VectorEnv(env_fns)
    policy = MyPolicy()
    env = env_fns[0]()

    # --- single environment ---
    single = Collector(policy, env, make_buffer(), preprocess_fn)
    single.collect(n_step=3, log_fn=logger.log)
    assert np.allclose(single.buffer.obs[:3], [0, 1, 0])
    assert np.allclose(single.buffer[:3].obs_next, [1, 2, 1])
    single.collect(n_episode=3, log_fn=logger.log)
    assert np.allclose(single.buffer.obs[:8], [0, 1, 0, 1, 0, 1, 0, 1])
    assert np.allclose(single.buffer[:8].obs_next, [1, 2, 1, 2, 1, 2, 1, 2])
    single.collect(n_step=3, random=True)

    # --- subprocess vector environment ---
    subproc = Collector(policy, venv, make_buffer(), preprocess_fn)
    subproc.collect(n_step=6)
    expected_obs = [0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3]
    expected_next = [1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4]
    assert np.allclose(subproc.buffer.obs[:11], expected_obs)
    assert np.allclose(subproc.buffer[:11].obs_next, expected_next)
    subproc.collect(n_episode=2)
    expected_obs = [0, 1, 2, 3, 4, 0, 1, 0, 1, 2]
    expected_next = [1, 2, 3, 4, 5, 1, 2, 1, 2, 3]
    assert np.allclose(subproc.buffer.obs[11:21], expected_obs)
    assert np.allclose(subproc.buffer[11:21].obs_next, expected_next)
    subproc.collect(n_episode=3, random=True)

    # --- in-process vector environment, per-env episode counts ---
    dummy = Collector(policy, dum, make_buffer(), preprocess_fn)
    dummy.collect(n_episode=[1, 2, 2, 2])
    expected_next = [
        1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5,
        1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5,
    ]
    assert np.allclose(dummy.buffer.obs_next[:26], expected_next)
    dummy.reset_env()
    dummy.collect(n_episode=[2, 2, 2, 2])
    expected_next = [
        1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4, 1, 2, 3, 4, 5,
        1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5,
    ]
    assert np.allclose(dummy.buffer.obs_next[26:54], expected_next)
    dummy.collect(n_episode=[1, 1, 1, 1], random=True)
|
2020-03-25 14:08:28 +08:00
|
|
|
|
|
|
|
|
2020-04-28 20:56:02 +08:00
|
|
|
def test_collector_with_dict_state():
    """Run ``Collector`` on envs created with ``dict_state=True``.

    Data is collected from a single env and from a ``VectorEnv``, the two
    buffers are merged, and the merged ``obs.index`` sequence is compared with
    a hard-coded list. A final collector uses a buffer with ``stack_num=4``
    and only prints a sampled batch.
    """
    policy = None  # placeholder so the construction order below is explicit
    env = MyTestEnv(size=5, sleep=0, dict_state=True)
    policy = MyPolicy(dict_state=True)

    single = Collector(policy, env, ReplayBuffer(size=100), preprocess_fn)
    single.collect(n_step=3)
    single.collect(n_episode=3)

    # bind the size as a default argument so each factory keeps its own value
    env_fns = [
        lambda size=n: MyTestEnv(size=size, sleep=0, dict_state=True)
        for n in (2, 3, 4, 5)
    ]
    envs = VectorEnv(env_fns)

    vector = Collector(policy, envs, ReplayBuffer(size=100), preprocess_fn)
    vector.collect(n_step=10)
    vector.collect(n_episode=[2, 1, 1, 2])
    batch = vector.sample(10)
    print(batch)

    # append the vector collector's data after the single-env data
    single.buffer.update(vector.buffer)
    expected_index = [
        0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1.,
        0., 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4., 0., 1., 0.,
        1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4.,
    ]
    merged = single.buffer[:len(single.buffer)]
    assert np.allclose(merged.obs.index, expected_index)

    stacked = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4),
                        preprocess_fn)
    stacked.collect(n_episode=[0, 0, 0, 10])
    batch = stacked.sample(10)
    print(batch['obs_next']['index'])
|
2020-04-28 20:56:02 +08:00
|
|
|
|
|
|
|
|
2020-07-13 00:24:31 +08:00
|
|
|
def test_collector_with_ma():
    """Run ``Collector`` on envs created with ``ma_rew=4`` and a reward metric.

    Each ``collect`` call must report a single-valued ``rew`` equal to the
    expected number, and the merged buffer must hold the hard-coded ``obs``
    sequence with every reward repeated four times per step.
    """
    def reward_metric(x):
        return x.sum()

    def check_scalar_rew(result, expected):
        # the collected reward must be one value that equals ``expected``
        r = result['rew']
        assert np.asanyarray(r).size == 1 and r == expected

    env = MyTestEnv(size=5, sleep=0, ma_rew=4)
    policy = MyPolicy()

    single = Collector(policy, env, ReplayBuffer(size=100),
                       preprocess_fn, reward_metric=reward_metric)
    check_scalar_rew(single.collect(n_step=3), 0.)
    check_scalar_rew(single.collect(n_episode=3), 4.)

    # bind the size as a default argument so each factory keeps its own value
    env_fns = [
        lambda size=n: MyTestEnv(size=size, sleep=0, ma_rew=4)
        for n in (2, 3, 4, 5)
    ]
    envs = VectorEnv(env_fns)

    vector = Collector(policy, envs, ReplayBuffer(size=100),
                       preprocess_fn, reward_metric=reward_metric)
    check_scalar_rew(vector.collect(n_step=10), 4.)
    check_scalar_rew(vector.collect(n_episode=[2, 1, 1, 2]), 4.)
    batch = vector.sample(10)
    print(batch)

    # append the vector collector's data after the single-env data
    single.buffer.update(vector.buffer)
    merged = single.buffer[:len(single.buffer)]
    obs = [
        0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1.,
        0., 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4., 0., 1., 0.,
        1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4.]
    assert np.allclose(merged.obs, obs)

    step_rew = [0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1,
                0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0,
                0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1]
    expected_rew = [[value] * 4 for value in step_rew]
    assert np.allclose(single.buffer[:len(single.buffer)].rew, expected_rew)

    stacked = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4),
                        preprocess_fn, reward_metric=reward_metric)
    check_scalar_rew(stacked.collect(n_episode=[0, 0, 0, 10]), 4.)
    batch = stacked.sample(10)
    print(batch['obs_next'])
|
|
|
|
|
|
|
|
|
2020-03-25 14:08:28 +08:00
|
|
|
if __name__ == '__main__':
    # run every test in this module, in definition order, when executed
    # directly as a script
    for run_test in (
        test_collector,
        test_collector_with_dict_state,
        test_collector_with_ma,
    ):
        run_test()
|