env info log_fn (#28)

Trinkle23897 2020-04-10 18:02:05 +08:00
parent ecfcb9f295
commit 74407e13da
5 changed files with 32 additions and 8 deletions

View File

@@ -27,4 +27,4 @@ class MyTestEnv(gym.Env):
         elif action == 1:
             self.index += 1
             self.done = self.index == self.size
-        return self.index, int(self.done), self.done, {}
+        return self.index, int(self.done), self.done, {'key': 1}

View File

@@ -1,4 +1,6 @@
 import numpy as np
+from torch.utils.tensorboard import SummaryWriter
 from tianshou.policy import BasePolicy
 from tianshou.env import SubprocVectorEnv
 from tianshou.data import Collector, Batch, ReplayBuffer

@@ -26,21 +28,34 @@ def equal(a, b):
     return abs(np.array(a) - np.array(b)).sum() < 1e-6


+class Logger(object):
+    def __init__(self, writer):
+        self.cnt = 0
+        self.writer = writer
+
+    def log(self, info):
+        self.writer.add_scalar('key', info['key'], global_step=self.cnt)
+        self.cnt += 1
+
+
 def test_collector():
+    writer = SummaryWriter('log/collector')
+    logger = Logger(writer)
     env_fns = [
         lambda: MyTestEnv(size=2, sleep=0),
         lambda: MyTestEnv(size=3, sleep=0),
         lambda: MyTestEnv(size=4, sleep=0),
         lambda: MyTestEnv(size=5, sleep=0),
     ]
     venv = SubprocVectorEnv(env_fns)
     policy = MyPolicy()
     env = env_fns[0]()
     c0 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=False))
-    c0.collect(n_step=3)
+    c0.collect(n_step=3, log_fn=logger.log)
     assert equal(c0.buffer.obs[:3], [0, 1, 0])
     assert equal(c0.buffer[:3].obs_next, [1, 2, 1])
-    c0.collect(n_episode=3)
+    c0.collect(n_episode=3, log_fn=logger.log)
     assert equal(c0.buffer.obs[:8], [0, 1, 0, 1, 0, 1, 0, 1])
     assert equal(c0.buffer[:8].obs_next, [1, 2, 1, 2, 1, 2, 1, 2])
     c1 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False))

View File

@@ -169,7 +169,7 @@ class Collector(object):
                 isinstance(self.state, np.ndarray):
             self.state[id] = 0

-    def collect(self, n_step=0, n_episode=0, render=None):
+    def collect(self, n_step=0, n_episode=0, render=None, log_fn=None):
         """Collect a specified number of step or episode.

         :param int n_step: how many steps you want to collect.

@@ -178,6 +178,8 @@
         :type n_episode: int or list
         :param float render: the sleep time between rendering consecutive
             frames, defaults to ``None`` (no rendering).
+        :param function log_fn: a function which receives env info, typically
+            for tensorboard logging.

         .. note::

@@ -232,6 +234,8 @@
             self._act = result.act
             obs_next, self._rew, self._done, self._info = self.env.step(
                 self._act if self._multi_env else self._act[0])
+            if log_fn is not None:
+                log_fn(self._info)
             if render is not None:
                 self.env.render()
                 if render > 0:
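
Since collect now invokes log_fn(self._info) once per environment step, any callable that accepts the env info dict can be plugged in. A minimal sketch of such a callback, mirroring the Logger in the test above (the scalar tag 'key' and the log directory are only illustrative, and with a vectorized env the argument is a list of info dicts rather than a single dict):

from torch.utils.tensorboard import SummaryWriter

class InfoLogger:
    """Write one field of the env info dict to TensorBoard on every step."""

    def __init__(self, writer):
        self.writer = writer
        self.step = 0

    def log(self, info):
        # `info` is the dict returned by a single env; a vectorized env
        # yields a list of dicts, which would need to be indexed first.
        self.writer.add_scalar('key', info['key'], global_step=self.step)
        self.step += 1

logger = InfoLogger(SummaryWriter('log/collector'))
# collector.collect(n_step=3, log_fn=logger.log)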

View File

@@ -7,7 +7,8 @@ from tianshou.trainer import test_episode, gather_info
 def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                       step_per_epoch, collect_per_step, episode_per_test,
-                      batch_size, train_fn=None, test_fn=None, stop_fn=None,
+                      batch_size,
+                      train_fn=None, test_fn=None, stop_fn=None, log_fn=None,
                       writer=None, log_interval=1, verbose=True, task='',
                       **kwargs):
     """A wrapper for off-policy trainer procedure.

@@ -37,6 +38,7 @@
     :param function stop_fn: a function receives the average undiscounted
         returns of the testing result, return a boolean which indicates whether
         reaching the goal.
+    :param function log_fn: a function receives env info for logging.
     :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
         SummaryWriter.
     :param int log_interval: the log interval of the writer.

@@ -56,7 +58,8 @@
         with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
                        **tqdm_config) as t:
             while t.n < t.total:
-                result = train_collector.collect(n_step=collect_per_step)
+                result = train_collector.collect(n_step=collect_per_step,
+                                                 log_fn=log_fn)
                 data = {}
                 if stop_fn and stop_fn(result['rew']):
                     test_result = test_episode(

View File

@@ -8,7 +8,7 @@ from tianshou.trainer import test_episode, gather_info
 def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                      step_per_epoch, collect_per_step, repeat_per_collect,
                      episode_per_test, batch_size,
-                     train_fn=None, test_fn=None, stop_fn=None,
+                     train_fn=None, test_fn=None, stop_fn=None, log_fn=None,
                      writer=None, log_interval=1, verbose=True, task='',
                      **kwargs):
     """A wrapper for on-policy trainer procedure.

@@ -42,6 +42,7 @@
     :param function stop_fn: a function receives the average undiscounted
         returns of the testing result, return a boolean which indicates whether
         reaching the goal.
+    :param function log_fn: a function receives env info for logging.
     :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
         SummaryWriter.
     :param int log_interval: the log interval of the writer.

@@ -61,7 +62,8 @@
         with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
                        **tqdm_config) as t:
             while t.n < t.total:
-                result = train_collector.collect(n_episode=collect_per_step)
+                result = train_collector.collect(n_episode=collect_per_step,
+                                                 log_fn=log_fn)
                 data = {}
                 if stop_fn and stop_fn(result['rew']):
                     test_result = test_episode(
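
With the updated signatures above, the same callback can be forwarded through the trainers, which simply hand it to train_collector.collect. A hedged usage sketch; the policy, collectors, and hyper-parameter values below are placeholders, not part of this commit:

# Assumes `policy`, `train_collector`, `test_collector`, and `logger`
# (as sketched after the collector diff) are already constructed.
result = onpolicy_trainer(
    policy, train_collector, test_collector,
    max_epoch=10, step_per_epoch=1000, collect_per_step=10,
    repeat_per_collect=2, episode_per_test=10, batch_size=64,
    log_fn=logger.log, writer=logger.writer)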