2020-03-25 14:08:28 +08:00
|
|
|
import numpy as np
|
2020-04-10 18:02:05 +08:00
|
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
|
|
|
2020-03-25 14:08:28 +08:00
|
|
|
from tianshou.policy import BasePolicy
|
2020-08-19 15:00:24 +08:00
|
|
|
from tianshou.env import DummyVectorEnv, SubprocVectorEnv
|
2020-04-10 09:01:17 +08:00
|
|
|
from tianshou.data import Collector, Batch, ReplayBuffer
|
2020-03-25 14:08:28 +08:00
|
|
|
|
|
|
|
if __name__ == '__main__':
    # run as a script: the env module sits next to this file
    from env import MyTestEnv
else:  # pytest
    # run under pytest from the repository root
    from test.base.env import MyTestEnv
|
class MyPolicy(BasePolicy):
    """Deterministic dummy policy that always emits action 1.

    :param bool dict_state: if the observation of the environment is a dict
    :param bool need_state: if the policy needs the hidden state (for RNN)
    """

    def __init__(self, dict_state: bool = False, need_state: bool = True):
        super().__init__()
        self.dict_state = dict_state
        self.need_state = need_state

    def forward(self, batch, state=None):
        """Return a constant all-ones action batch plus a fake hidden state."""
        if self.need_state:
            if state is None:
                state = np.zeros((len(batch.obs), 2))
            else:
                state += 1  # in-place bump mimics an evolving RNN state
        if self.dict_state:
            batch_size = len(batch.obs['index'])
        else:
            batch_size = len(batch.obs)
        return Batch(act=np.ones(batch_size), state=state)

    def learn(self):
        """No-op: this test policy never trains."""
        pass
|
2020-07-23 16:40:53 +08:00
|
|
|
class Logger:
    """Wrap a tensorboard writer and provide buffer preprocess hooks."""

    def __init__(self, writer):
        self.cnt = 0          # global step counter for the scalar log
        self.writer = writer  # tensorboard SummaryWriter

    def preprocess_fn(self, **kwargs):
        """Copy rewards into each info dict and log 'key' to tensorboard.

        The collector calls this with only ``obs`` on reset, and with the
        full ``obs/act/rew/done/...`` set on a normal step; the modified
        info is recorded into the buffer (and into tfb).
        """
        if 'rew' not in kwargs:
            # reset call: nothing to record
            return Batch()
        info = kwargs['info']
        for i in range(len(kwargs['obs'])):
            info[i].update(rew=kwargs['rew'][i])
        if 'key' in info.keys():
            self.writer.add_scalar(
                'key', np.mean(info['key']), global_step=self.cnt)
        self.cnt += 1
        return Batch(info=info)

    @staticmethod
    def single_preprocess_fn(**kwargs):
        """Same as :meth:`preprocess_fn`, but without tensorboard logging."""
        if 'rew' not in kwargs:
            return Batch()
        info = kwargs['info']
        for i in range(len(kwargs['obs'])):
            info[i].update(rew=kwargs['rew'][i])
        return Batch(info=info)
|
2020-03-25 14:08:28 +08:00
|
|
|
def test_collector():
    """Collect with single, subprocess and dummy vectorized environments."""
    writer = SummaryWriter('log/collector')
    logger = Logger(writer)
    env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0) for i in [2, 3, 4, 5]]

    subproc_venv = SubprocVectorEnv(env_fns)
    dummy_venv = DummyVectorEnv(env_fns)
    policy = MyPolicy()
    single_env = env_fns[0]()

    # single environment collector
    c0 = Collector(policy, single_env,
                   ReplayBuffer(size=100, ignore_obs_next=False),
                   logger.preprocess_fn)
    c0.collect(n_step=3)
    assert np.allclose(c0.buffer.obs[:4, 0], [0, 1, 0, 1])
    assert np.allclose(c0.buffer[:4].obs_next[..., 0], [1, 2, 1, 2])
    c0.collect(n_episode=3)
    assert np.allclose(c0.buffer.obs[:10, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 1])
    assert np.allclose(c0.buffer[:10].obs_next[..., 0],
                       [1, 2, 1, 2, 1, 2, 1, 2, 1, 2])
    c0.collect(n_step=3, random=True)

    # subprocess vectorized collector
    c1 = Collector(policy, subproc_venv,
                   ReplayBuffer(size=100, ignore_obs_next=False),
                   logger.preprocess_fn)
    c1.collect(n_step=6)
    assert np.allclose(c1.buffer.obs[:11, 0],
                       [0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3])
    assert np.allclose(c1.buffer[:11].obs_next[..., 0],
                       [1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4])
    c1.collect(n_episode=2)
    assert np.allclose(c1.buffer.obs[11:21, 0],
                       [0, 1, 2, 3, 4, 0, 1, 0, 1, 2])
    assert np.allclose(c1.buffer[11:21].obs_next[..., 0],
                       [1, 2, 3, 4, 5, 1, 2, 1, 2, 3])
    c1.collect(n_episode=3, random=True)

    # dummy vectorized collector with per-env episode counts
    c2 = Collector(policy, dummy_venv,
                   ReplayBuffer(size=100, ignore_obs_next=False),
                   logger.preprocess_fn)
    c2.collect(n_episode=[1, 2, 2, 2])
    assert np.allclose(c2.buffer.obs_next[:26, 0], [
        1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5,
        1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5])
    c2.reset_env()
    c2.collect(n_episode=[2, 2, 2, 2])
    assert np.allclose(c2.buffer.obs_next[26:54, 0], [
        1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4, 1, 2, 3, 4, 5,
        1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5])
    c2.collect(n_episode=[1, 1, 1, 1], random=True)
|
2020-08-27 12:15:18 +08:00
|
|
|
def test_collector_with_exact_episodes():
    """Check the total step count when exact per-env episode counts are given."""
    env_lens = [2, 6, 3, 10]
    writer = SummaryWriter('log/exact_collector')
    logger = Logger(writer)
    env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0.1, random_sleep=True)
               for i in env_lens]

    venv = SubprocVectorEnv(env_fns, wait_num=len(env_fns) - 1)
    policy = MyPolicy()
    c1 = Collector(policy, venv,
                   ReplayBuffer(size=1000, ignore_obs_next=False),
                   logger.preprocess_fn)
    n_episode1 = [2, 0, 5, 1]
    n_episode2 = [1, 3, 2, 0]

    c1.collect(n_episode=n_episode1)
    # each requested episode runs its env to full length exactly once
    expected_steps = sum(l * e for l, e in zip(env_lens, n_episode1))
    actual_steps = sum(venv.steps)
    assert expected_steps == actual_steps

    c1.collect(n_episode=n_episode2)
    expected_steps = sum(
        l * (e1 + e2) for l, e1, e2 in zip(env_lens, n_episode1, n_episode2))
    actual_steps = sum(venv.steps)
    assert expected_steps == actual_steps
|
2020-07-26 12:01:21 +02:00
|
|
|
def test_collector_with_async():
    """Collect asynchronously and verify the buffer stays chronological."""
    env_lens = [2, 3, 4, 5]
    writer = SummaryWriter('log/async_collector')
    logger = Logger(writer)
    env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0.1, random_sleep=True)
               for i in env_lens]

    venv = SubprocVectorEnv(env_fns, wait_num=len(env_fns) - 1)
    policy = MyPolicy()
    c1 = Collector(policy, venv,
                   ReplayBuffer(size=1000, ignore_obs_next=False),
                   logger.preprocess_fn)
    c1.collect(n_episode=10)

    # the buffer must hold only full episodes, each produced end-to-end
    # by a single environment
    env_id = c1.buffer.info['env_id']
    size = len(c1.buffer)
    obs = c1.buffer.obs[:size]
    done = c1.buffer.done[:size]
    obs_ground_truth = []
    i = 0
    while i < size:
        # i marks the start of an episode
        if done[i]:
            # a one-transition episode can only come from a length-1 env
            assert env_lens[env_id[i]] == 1
            i += 1
            continue
        j = i
        while True:
            j += 1
            # within one episode the environment id never changes
            assert env_id[j] == env_id[i]
            if done[j]:
                break
        j += 1  # j is now the start of the next episode
        assert j - i == env_lens[env_id[i]]
        obs_ground_truth += list(range(j - i))
        i = j
    obs_ground_truth = np.expand_dims(np.array(obs_ground_truth), axis=-1)
    assert np.allclose(obs, obs_ground_truth)
2020-04-28 20:56:02 +08:00
|
|
|
def test_collector_with_dict_state():
    """Collect from environments whose observations are dicts."""
    env = MyTestEnv(size=5, sleep=0, dict_state=True)
    policy = MyPolicy(dict_state=True)
    c0 = Collector(policy, env, ReplayBuffer(size=100),
                   Logger.single_preprocess_fn)
    c0.collect(n_step=3)
    c0.collect(n_episode=2)

    env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, dict_state=True)
               for i in [2, 3, 4, 5]]
    envs = DummyVectorEnv(env_fns)
    envs.seed(666)
    obs = envs.reset()
    # seeding must give each env its own random observation entry
    assert not np.isclose(obs[0]['rand'], obs[1]['rand'])

    c1 = Collector(policy, envs, ReplayBuffer(size=100),
                   Logger.single_preprocess_fn)
    c1.collect(n_step=10)
    c1.collect(n_episode=[2, 1, 1, 2])
    batch, _ = c1.buffer.sample(10)
    print(batch)
    c0.buffer.update(c1.buffer)
    assert np.allclose(c0.buffer[:len(c0.buffer)].obs.index[..., 0], [
        0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1.,
        0., 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4., 0., 1., 0.,
        1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4.])

    c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4),
                   Logger.single_preprocess_fn)
    c2.collect(n_episode=[0, 0, 0, 10])
    batch, _ = c2.buffer.sample(10)
2020-07-13 00:24:31 +08:00
|
|
|
def test_collector_with_ma():
    """Collect multi-agent rewards reduced to a scalar via reward_metric."""
    def reward_metric(x):
        # collapse the 4-agent reward vector into a scalar
        return x.sum()

    env = MyTestEnv(size=5, sleep=0, ma_rew=4)
    policy = MyPolicy()
    c0 = Collector(policy, env, ReplayBuffer(size=100),
                   Logger.single_preprocess_fn, reward_metric=reward_metric)
    # n_step=3 will collect a full episode
    r = c0.collect(n_step=3)['rew']
    assert np.asanyarray(r).size == 1 and r == 4.
    r = c0.collect(n_episode=2)['rew']
    assert np.asanyarray(r).size == 1 and r == 4.

    env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, ma_rew=4)
               for i in [2, 3, 4, 5]]
    envs = DummyVectorEnv(env_fns)
    c1 = Collector(policy, envs, ReplayBuffer(size=100),
                   Logger.single_preprocess_fn, reward_metric=reward_metric)
    r = c1.collect(n_step=10)['rew']
    assert np.asanyarray(r).size == 1 and r == 4.
    r = c1.collect(n_episode=[2, 1, 1, 2])['rew']
    assert np.asanyarray(r).size == 1 and r == 4.
    batch, _ = c1.buffer.sample(10)
    print(batch)
    c0.buffer.update(c1.buffer)
    assert np.allclose(c0.buffer[:len(c0.buffer)].obs[..., 0], [
        0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1.,
        0., 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4., 0., 1., 0.,
        1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4.])
    rew = [0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1,
           0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0,
           0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1]
    # every scalar reward is broadcast back to all 4 agents in the buffer
    assert np.allclose(c0.buffer[:len(c0.buffer)].rew,
                       [[x] * 4 for x in rew])

    c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4),
                   Logger.single_preprocess_fn, reward_metric=reward_metric)
    r = c2.collect(n_episode=[0, 0, 0, 10])['rew']
    assert np.asanyarray(r).size == 1 and r == 4.
    batch, _ = c2.buffer.sample(10)
2020-03-25 14:08:28 +08:00
|
|
|
if __name__ == '__main__':
    # run every test in sequence when invoked as a script
    test_collector()
    test_collector_with_dict_state()
    test_collector_with_ma()
    test_collector_with_async()
    test_collector_with_exact_episodes()