env info log_fn (#28)

parent: ecfcb9f295
commit: 74407e13da
@@ -27,4 +27,4 @@ class MyTestEnv(gym.Env):
         elif action == 1:
             self.index += 1
             self.done = self.index == self.size
-            return self.index, int(self.done), self.done, {}
+            return self.index, int(self.done), self.done, {'key': 1}
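The only change here is that the test env's step() now returns a non-empty info dict (the fourth value of the standard gym step tuple), so the new logging hook has something to receive. A tiny illustrative sketch of that contract (the callback below is hypothetical, not part of the diff):

    # obs, reward, done, info = env.step(action)
    # 'info' is an arbitrary dict; this commit makes the test env emit {'key': 1}
    def log_fn(info):
        print('env reported key =', info.get('key'))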
@@ -1,4 +1,6 @@
 import numpy as np
+from torch.utils.tensorboard import SummaryWriter
+
 from tianshou.policy import BasePolicy
 from tianshou.env import SubprocVectorEnv
 from tianshou.data import Collector, Batch, ReplayBuffer
@@ -26,21 +28,34 @@ def equal(a, b):
     return abs(np.array(a) - np.array(b)).sum() < 1e-6
 
+
+class Logger(object):
+    def __init__(self, writer):
+        self.cnt = 0
+        self.writer = writer
+
+    def log(self, info):
+        self.writer.add_scalar('key', info['key'], global_step=self.cnt)
+        self.cnt += 1
+
+
 def test_collector():
+    writer = SummaryWriter('log/collector')
+    logger = Logger(writer)
     env_fns = [
         lambda: MyTestEnv(size=2, sleep=0),
         lambda: MyTestEnv(size=3, sleep=0),
         lambda: MyTestEnv(size=4, sleep=0),
         lambda: MyTestEnv(size=5, sleep=0),
     ]
 
     venv = SubprocVectorEnv(env_fns)
     policy = MyPolicy()
     env = env_fns[0]()
     c0 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=False))
-    c0.collect(n_step=3)
+    c0.collect(n_step=3, log_fn=logger.log)
     assert equal(c0.buffer.obs[:3], [0, 1, 0])
     assert equal(c0.buffer[:3].obs_next, [1, 2, 1])
-    c0.collect(n_episode=3)
+    c0.collect(n_episode=3, log_fn=logger.log)
     assert equal(c0.buffer.obs[:8], [0, 1, 0, 1, 0, 1, 0, 1])
     assert equal(c0.buffer[:8].obs_next, [1, 2, 1, 2, 1, 2, 1, 2])
     c1 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False))
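The Logger added in the test is tied to the single 'key' field emitted by the test env. A slightly more general sketch (illustrative only, not part of this commit) that logs every numeric field of the info dict works the same way:

    # generalized variant of the test's Logger (assumption: info is a dict)
    from numbers import Number
    from torch.utils.tensorboard import SummaryWriter

    class InfoLogger(object):
        def __init__(self, writer):
            self.cnt = 0
            self.writer = writer

        def log(self, info):
            # write every numeric entry of the env info dict to tensorboard
            for k, v in info.items():
                if isinstance(v, Number):
                    self.writer.add_scalar(k, v, global_step=self.cnt)
            self.cnt += 1

    # usage mirrors the test above:
    #   logger = InfoLogger(SummaryWriter('log/collector'))
    #   collector.collect(n_step=3, log_fn=logger.log)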
@@ -169,7 +169,7 @@ class Collector(object):
                 isinstance(self.state, np.ndarray):
             self.state[id] = 0
 
-    def collect(self, n_step=0, n_episode=0, render=None):
+    def collect(self, n_step=0, n_episode=0, render=None, log_fn=None):
        """Collect a specified number of step or episode.
 
        :param int n_step: how many steps you want to collect.
@@ -178,6 +178,8 @@ class Collector(object):
        :type n_episode: int or list
        :param float render: the sleep time between rendering consecutive
            frames, defaults to ``None`` (no rendering).
+        :param function log_fn: a function which receives env info, typically
+            for tensorboard logging.
 
        .. note::
 
@@ -232,6 +234,8 @@ class Collector(object):
             self._act = result.act
             obs_next, self._rew, self._done, self._info = self.env.step(
                 self._act if self._multi_env else self._act[0])
+            if log_fn is not None:
+                log_fn(self._info)
             if render is not None:
                 self.env.render()
                 if render > 0:
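The callback is invoked once per collection step with whatever env.step() returned as info. For a vectorized environment, self._info may hold one entry per sub-environment rather than a single dict, so a callback intended for SubprocVectorEnv should be prepared for a batch of infos. A hedged sketch (the exact container type for vectorized envs is an assumption, not fixed by this diff):

    # tolerate either a single info dict or a sequence of per-env info dicts
    import numpy as np

    def robust_log_fn(info):
        infos = info if isinstance(info, (list, tuple, np.ndarray)) else [info]
        for item in infos:
            if isinstance(item, dict) and 'key' in item:
                print('key =', item['key'])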
@@ -7,7 +7,8 @@ from tianshou.trainer import test_episode, gather_info
 
 def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                       step_per_epoch, collect_per_step, episode_per_test,
-                      batch_size, train_fn=None, test_fn=None, stop_fn=None,
+                      batch_size,
+                      train_fn=None, test_fn=None, stop_fn=None, log_fn=None,
                       writer=None, log_interval=1, verbose=True, task='',
                       **kwargs):
     """A wrapper for off-policy trainer procedure.
@@ -37,6 +38,7 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
    :param function stop_fn: a function receives the average undiscounted
        returns of the testing result, return a boolean which indicates whether
        reaching the goal.
+    :param function log_fn: a function receives env info for logging.
    :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
        SummaryWriter.
    :param int log_interval: the log interval of the writer.
@@ -56,7 +58,8 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
         with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
                        **tqdm_config) as t:
             while t.n < t.total:
-                result = train_collector.collect(n_step=collect_per_step)
+                result = train_collector.collect(n_step=collect_per_step,
+                                                 log_fn=log_fn)
                 data = {}
                 if stop_fn and stop_fn(result['rew']):
                     test_result = test_episode(
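With this change, a callback passed to the trainer is forwarded to every train_collector.collect call during training. A sketch of the call site, assuming placeholder policy/collector/writer objects and hyperparameters defined elsewhere (only the log_fn keyword comes from this commit):

    # hypothetical call site; all arguments other than log_fn are placeholders
    result = offpolicy_trainer(
        policy, train_collector, test_collector,
        max_epoch=10, step_per_epoch=1000, collect_per_step=10,
        episode_per_test=100, batch_size=64,
        log_fn=logger.log,   # forwarded to train_collector.collect()
        writer=writer)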
@@ -8,7 +8,7 @@ from tianshou.trainer import test_episode, gather_info
 def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                      step_per_epoch, collect_per_step, repeat_per_collect,
                      episode_per_test, batch_size,
-                     train_fn=None, test_fn=None, stop_fn=None,
+                     train_fn=None, test_fn=None, stop_fn=None, log_fn=None,
                      writer=None, log_interval=1, verbose=True, task='',
                      **kwargs):
     """A wrapper for on-policy trainer procedure.
@@ -42,6 +42,7 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
    :param function stop_fn: a function receives the average undiscounted
        returns of the testing result, return a boolean which indicates whether
        reaching the goal.
+    :param function log_fn: a function receives env info for logging.
    :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
        SummaryWriter.
    :param int log_interval: the log interval of the writer.
@@ -61,7 +62,8 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
         with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
                        **tqdm_config) as t:
             while t.n < t.total:
-                result = train_collector.collect(n_episode=collect_per_step)
+                result = train_collector.collect(n_episode=collect_per_step,
+                                                 log_fn=log_fn)
                 data = {}
                 if stop_fn and stop_fn(result['rew']):
                     test_result = test_episode(