env info log_fn (#28)
parent ecfcb9f295, commit 74407e13da
@@ -27,4 +27,4 @@ class MyTestEnv(gym.Env):
         elif action == 1:
             self.index += 1
             self.done = self.index == self.size
-            return self.index, int(self.done), self.done, {}
+            return self.index, int(self.done), self.done, {'key': 1}
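The only functional change in this hunk is that the test environment's step() now returns a non-empty info dict. For orientation, a gym-style step() returns (observation, reward, done, info), and whatever goes into info is exactly what the new log_fn hook will later receive. A minimal, hypothetical environment along the same lines (illustrative only, not part of this commit):

import gym


class CounterEnv(gym.Env):
    """Toy env that exposes diagnostics through the info dict of step()."""

    def __init__(self, size=5):
        self.size = size
        self.index = 0

    def reset(self):
        self.index = 0
        return self.index

    def step(self, action):
        # walk forward until the counter reaches `size`
        self.index += 1
        done = self.index == self.size
        # the 4th return value is the info dict that a log_fn callback sees
        return self.index, int(done), done, {'key': 1}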
@@ -1,4 +1,6 @@
 import numpy as np
+from torch.utils.tensorboard import SummaryWriter
+
 from tianshou.policy import BasePolicy
 from tianshou.env import SubprocVectorEnv
 from tianshou.data import Collector, Batch, ReplayBuffer
@@ -26,21 +28,34 @@ def equal(a, b):
     return abs(np.array(a) - np.array(b)).sum() < 1e-6


+class Logger(object):
+    def __init__(self, writer):
+        self.cnt = 0
+        self.writer = writer
+
+    def log(self, info):
+        self.writer.add_scalar('key', info['key'], global_step=self.cnt)
+        self.cnt += 1
+
+
 def test_collector():
+    writer = SummaryWriter('log/collector')
+    logger = Logger(writer)
     env_fns = [
         lambda: MyTestEnv(size=2, sleep=0),
         lambda: MyTestEnv(size=3, sleep=0),
         lambda: MyTestEnv(size=4, sleep=0),
         lambda: MyTestEnv(size=5, sleep=0),
     ]
+
     venv = SubprocVectorEnv(env_fns)
     policy = MyPolicy()
     env = env_fns[0]()
     c0 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=False))
-    c0.collect(n_step=3)
+    c0.collect(n_step=3, log_fn=logger.log)
     assert equal(c0.buffer.obs[:3], [0, 1, 0])
     assert equal(c0.buffer[:3].obs_next, [1, 2, 1])
-    c0.collect(n_episode=3)
+    c0.collect(n_episode=3, log_fn=logger.log)
     assert equal(c0.buffer.obs[:8], [0, 1, 0, 1, 0, 1, 0, 1])
     assert equal(c0.buffer[:8].obs_next, [1, 2, 1, 2, 1, 2, 1, 2])
     c1 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False))
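The Logger above is just a thin stateful wrapper around SummaryWriter so that every call advances a step counter. Any callable taking a single info argument can be passed as log_fn; a hedged alternative sketch using a plain closure instead of a class (the names below are illustrative, not part of the test):

def make_log_fn(writer):
    """Return a log_fn closure that writes info['key'] to tensorboard."""
    state = {'step': 0}

    def log_fn(info):
        writer.add_scalar('key', info['key'], global_step=state['step'])
        state['step'] += 1

    return log_fn


c0.collect(n_step=3, log_fn=make_log_fn(writer))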
@@ -169,7 +169,7 @@ class Collector(object):
                 isinstance(self.state, np.ndarray):
             self.state[id] = 0

-    def collect(self, n_step=0, n_episode=0, render=None):
+    def collect(self, n_step=0, n_episode=0, render=None, log_fn=None):
         """Collect a specified number of step or episode.

         :param int n_step: how many steps you want to collect.
@@ -178,6 +178,8 @@ class Collector(object):
         :type n_episode: int or list
         :param float render: the sleep time between rendering consecutive
             frames, defaults to ``None`` (no rendering).
+        :param function log_fn: a function which receives env info, typically
+            for tensorboard logging.

         .. note::

@@ -232,6 +234,8 @@ class Collector(object):
             self._act = result.act
             obs_next, self._rew, self._done, self._info = self.env.step(
                 self._act if self._multi_env else self._act[0])
+            if log_fn is not None:
+                log_fn(self._info)
             if render is not None:
                 self.env.render()
                 if render > 0:
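As the hunk above shows, the collector simply forwards whatever env.step() returned as info to log_fn once per collected step, right before any rendering. With a single environment that is one dict; with a vectorized environment the forwarded value is presumably the batched per-env info, so a callback meant to work in both settings can normalize its input first. A hedged sketch, not code from this commit:

import numpy as np


def robust_log_fn(info):
    # info may be a single dict (single env) or a sequence of dicts
    # (vectorized env); normalize before reading fields out of it.
    if isinstance(info, dict):
        infos = [info]
    else:
        infos = list(np.asarray(info).reshape(-1))
    for item in infos:
        if 'key' in item:
            print('key =', item['key'])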
@@ -7,7 +7,8 @@ from tianshou.trainer import test_episode, gather_info

 def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                       step_per_epoch, collect_per_step, episode_per_test,
-                      batch_size, train_fn=None, test_fn=None, stop_fn=None,
+                      batch_size,
+                      train_fn=None, test_fn=None, stop_fn=None, log_fn=None,
                       writer=None, log_interval=1, verbose=True, task='',
                       **kwargs):
     """A wrapper for off-policy trainer procedure.
@@ -37,6 +38,7 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
     :param function stop_fn: a function receives the average undiscounted
         returns of the testing result, return a boolean which indicates whether
         reaching the goal.
+    :param function log_fn: a function receives env info for logging.
     :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
         SummaryWriter.
     :param int log_interval: the log interval of the writer.
@@ -56,7 +58,8 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
         with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
                        **tqdm_config) as t:
             while t.n < t.total:
-                result = train_collector.collect(n_step=collect_per_step)
+                result = train_collector.collect(n_step=collect_per_step,
+                                                 log_fn=log_fn)
                 data = {}
                 if stop_fn and stop_fn(result['rew']):
                     test_result = test_episode(
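With this change, log_fn only needs to be supplied once to the trainer, which then forwards it to every training-time collect call. A hedged usage sketch, reusing the Logger helper from the test above; the policy, collectors, and hyper-parameter values are placeholders rather than values taken from this commit:

writer = SummaryWriter('log/offpolicy')
logger = Logger(writer)

result = offpolicy_trainer(
    policy, train_collector, test_collector,
    max_epoch=10, step_per_epoch=1000, collect_per_step=10,
    episode_per_test=100, batch_size=64,
    log_fn=logger.log, writer=writer)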
@@ -8,7 +8,7 @@ from tianshou.trainer import test_episode, gather_info
 def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                      step_per_epoch, collect_per_step, repeat_per_collect,
                      episode_per_test, batch_size,
-                     train_fn=None, test_fn=None, stop_fn=None,
+                     train_fn=None, test_fn=None, stop_fn=None, log_fn=None,
                      writer=None, log_interval=1, verbose=True, task='',
                      **kwargs):
     """A wrapper for on-policy trainer procedure.
@@ -42,6 +42,7 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
     :param function stop_fn: a function receives the average undiscounted
         returns of the testing result, return a boolean which indicates whether
         reaching the goal.
+    :param function log_fn: a function receives env info for logging.
     :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
         SummaryWriter.
     :param int log_interval: the log interval of the writer.
@@ -61,7 +62,8 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
         with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
                        **tqdm_config) as t:
             while t.n < t.total:
-                result = train_collector.collect(n_episode=collect_per_step)
+                result = train_collector.collect(n_episode=collect_per_step,
+                                                 log_fn=log_fn)
                 data = {}
                 if stop_fn and stop_fn(result['rew']):
                     test_result = test_episode(
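The on-policy trainer gets the same treatment; the only visible difference is that its inner loop collects by episode (n_episode=collect_per_step) rather than by step. An analogous hedged call, again with placeholder values:

result = onpolicy_trainer(
    policy, train_collector, test_collector,
    max_epoch=10, step_per_epoch=1000, collect_per_step=10,
    repeat_per_collect=2, episode_per_test=100, batch_size=64,
    log_fn=logger.log, writer=writer)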