env info log_fn (#28)

Trinkle23897 2020-04-10 18:02:05 +08:00
parent ecfcb9f295
commit 74407e13da
5 changed files with 32 additions and 8 deletions
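
In short, this commit lets the info dict returned by env.step() flow into a user callback during collection. A minimal usage sketch, mirroring the test added below (the commented-out collector setup lines are placeholders, not part of this commit):

from torch.utils.tensorboard import SummaryWriter

class Logger(object):
    """Forward one field of the env info dict to TensorBoard, one point per call."""
    def __init__(self, writer):
        self.cnt = 0
        self.writer = writer

    def log(self, info):
        # 'key' is whatever field the environment writes into its info dict
        self.writer.add_scalar('key', info['key'], global_step=self.cnt)
        self.cnt += 1

writer = SummaryWriter('log/collector')
logger = Logger(writer)
# collector = Collector(policy, env, ReplayBuffer(size=100))   # set up as usual
# collector.collect(n_step=3, log_fn=logger.log)               # info logged every env step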

View File

@@ -27,4 +27,4 @@ class MyTestEnv(gym.Env):
         elif action == 1:
             self.index += 1
             self.done = self.index == self.size
-            return self.index, int(self.done), self.done, {}
+            return self.index, int(self.done), self.done, {'key': 1}

View File

@@ -1,4 +1,6 @@
 import numpy as np
+from torch.utils.tensorboard import SummaryWriter
+
 from tianshou.policy import BasePolicy
 from tianshou.env import SubprocVectorEnv
 from tianshou.data import Collector, Batch, ReplayBuffer
@@ -26,21 +28,34 @@ def equal(a, b):
     return abs(np.array(a) - np.array(b)).sum() < 1e-6


+class Logger(object):
+    def __init__(self, writer):
+        self.cnt = 0
+        self.writer = writer
+
+    def log(self, info):
+        self.writer.add_scalar('key', info['key'], global_step=self.cnt)
+        self.cnt += 1
+
+
 def test_collector():
+    writer = SummaryWriter('log/collector')
+    logger = Logger(writer)
     env_fns = [
         lambda: MyTestEnv(size=2, sleep=0),
         lambda: MyTestEnv(size=3, sleep=0),
         lambda: MyTestEnv(size=4, sleep=0),
         lambda: MyTestEnv(size=5, sleep=0),
     ]
     venv = SubprocVectorEnv(env_fns)
     policy = MyPolicy()
     env = env_fns[0]()
     c0 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=False))
-    c0.collect(n_step=3)
+    c0.collect(n_step=3, log_fn=logger.log)
     assert equal(c0.buffer.obs[:3], [0, 1, 0])
     assert equal(c0.buffer[:3].obs_next, [1, 2, 1])
-    c0.collect(n_episode=3)
+    c0.collect(n_episode=3, log_fn=logger.log)
     assert equal(c0.buffer.obs[:8], [0, 1, 0, 1, 0, 1, 0, 1])
     assert equal(c0.buffer[:8].obs_next, [1, 2, 1, 2, 1, 2, 1, 2])
     c1 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False))

View File

@@ -169,7 +169,7 @@ class Collector(object):
                 isinstance(self.state, np.ndarray):
             self.state[id] = 0
 
-    def collect(self, n_step=0, n_episode=0, render=None):
+    def collect(self, n_step=0, n_episode=0, render=None, log_fn=None):
         """Collect a specified number of step or episode.
 
         :param int n_step: how many steps you want to collect.
@@ -178,6 +178,8 @@ class Collector(object):
         :type n_episode: int or list
         :param float render: the sleep time between rendering consecutive
             frames, defaults to ``None`` (no rendering).
+        :param function log_fn: a function which receives env info, typically
+            for tensorboard logging.
 
         .. note::
@@ -232,6 +234,8 @@ class Collector(object):
             self._act = result.act
             obs_next, self._rew, self._done, self._info = self.env.step(
                 self._act if self._multi_env else self._act[0])
+            if log_fn is not None:
+                log_fn(self._info)
             if render is not None:
                 self.env.render()
                 if render > 0:
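
Note that log_fn receives self._info exactly as returned by self.env.step(); with a vectorized environment this is presumably a batch of per-env info dicts rather than a single dict, so a defensive callback may want to handle both shapes. A hedged sketch (the helper name and the batching assumption are mine, not from this commit):

def make_log_fn(writer, key='key'):
    # Returns a callback suitable for collect(..., log_fn=...).
    # Assumption: info is either one dict (single env) or an iterable of
    # per-env dicts (vectorized env); only numeric values of `key` are written.
    state = {'step': 0}

    def log_fn(info):
        infos = [info] if isinstance(info, dict) else list(info)
        for item in infos:
            value = item.get(key)
            if isinstance(value, (int, float)):
                writer.add_scalar(key, value, global_step=state['step'])
            state['step'] += 1

    return log_fn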

View File

@@ -7,7 +7,8 @@ from tianshou.trainer import test_episode, gather_info
 
 def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                       step_per_epoch, collect_per_step, episode_per_test,
-                      batch_size, train_fn=None, test_fn=None, stop_fn=None,
+                      batch_size,
+                      train_fn=None, test_fn=None, stop_fn=None, log_fn=None,
                       writer=None, log_interval=1, verbose=True, task='',
                       **kwargs):
     """A wrapper for off-policy trainer procedure.
"""A wrapper for off-policy trainer procedure.
@@ -37,6 +38,7 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
     :param function stop_fn: a function receives the average undiscounted
         returns of the testing result, return a boolean which indicates whether
         reaching the goal.
+    :param function log_fn: a function receives env info for logging.
     :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
         SummaryWriter.
     :param int log_interval: the log interval of the writer.
@@ -56,7 +58,8 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
         with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
                        **tqdm_config) as t:
             while t.n < t.total:
-                result = train_collector.collect(n_step=collect_per_step)
+                result = train_collector.collect(n_step=collect_per_step,
+                                                 log_fn=log_fn)
                 data = {}
                 if stop_fn and stop_fn(result['rew']):
                     test_result = test_episode(
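
At the trainer level the new keyword is simply forwarded to train_collector.collect on every collection step. A sketch of a call site matching the signature above (all concrete objects and hyper-parameter values are placeholders):

result = offpolicy_trainer(
    policy, train_collector, test_collector,
    max_epoch=10, step_per_epoch=1000, collect_per_step=10,
    episode_per_test=50, batch_size=64,
    log_fn=logger.log,   # e.g. the Logger.log shown in the test file above
    writer=writer)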

View File

@@ -8,7 +8,7 @@ from tianshou.trainer import test_episode, gather_info
 def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                      step_per_epoch, collect_per_step, repeat_per_collect,
                      episode_per_test, batch_size,
-                     train_fn=None, test_fn=None, stop_fn=None,
+                     train_fn=None, test_fn=None, stop_fn=None, log_fn=None,
                      writer=None, log_interval=1, verbose=True, task='',
                      **kwargs):
     """A wrapper for on-policy trainer procedure.
@@ -42,6 +42,7 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
     :param function stop_fn: a function receives the average undiscounted
         returns of the testing result, return a boolean which indicates whether
         reaching the goal.
+    :param function log_fn: a function receives env info for logging.
     :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
         SummaryWriter.
     :param int log_interval: the log interval of the writer.
@@ -61,7 +62,8 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
         with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
                        **tqdm_config) as t:
             while t.n < t.total:
-                result = train_collector.collect(n_episode=collect_per_step)
+                result = train_collector.collect(n_episode=collect_per_step,
+                                                 log_fn=log_fn)
                 data = {}
                 if stop_fn and stop_fn(result['rew']):
                     test_result = test_episode(