Tianshou/tianshou/trainer/onpolicy.py
youkaichao 3a08e27ed4 Standardized behavior of Batch.cat and misc code refactor (#137)
* code refactor; remove unused kwargs; add reward_normalization for dqn

* bugfix for __setitem__ with torch.Tensor; add Batch.condense

* minor fix

* support cat with empty Batch

* remove the dependency of is_empty on len; specify the semantic of empty Batch by test cases

* support stack with empty Batch

* remove condense

* refactor code to reflect the shared / partial / reserved categories of keys

* add is_empty(recursive=False)

* doc fix

* docfix and bugfix for _is_batch_set

* add doc for key reservation

* bugfix for algebra operators

* fix cat with lens hint

* code refactor

* bugfix for storing None

* use ValueError instead of exception

* hide lens away from users

* add comment for __cat

* move the computation of the initial value of lens in cat_ itself.

* change the place of doc string

* doc fix for Batch doc string

* change recursive to recurse

* doc string fix

* minor fix for batch doc
2020-07-20 15:54:18 +08:00
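A minimal sketch (not part of the file below) of the Batch semantics this commit message describes, assuming the post-#137 tianshou.data.Batch API; the shapes are illustrative:

import numpy as np
from tianshou.data import Batch

b = Batch(obs=np.zeros((3, 4)))

# cat now accepts empty operands ("support cat with empty Batch" above)
merged = Batch.cat([b, Batch(), b])
assert merged.obs.shape == (6, 4)

# is_empty no longer depends on len(); ``recurse`` controls whether
# reserved (key-only, data-free) sub-batches count as empty
assert Batch().is_empty()
reserved = Batch(a=Batch())
assert not reserved.is_empty()          # the key 'a' is reserved
assert reserved.is_empty(recurse=True)  # but carries no actual data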

import time
import tqdm
from torch.utils.tensorboard import SummaryWriter
from typing import Dict, List, Union, Callable, Optional

from tianshou.data import Collector
from tianshou.policy import BasePolicy
from tianshou.utils import tqdm_config, MovAvg
from tianshou.trainer import test_episode, gather_info

def onpolicy_trainer(
        policy: BasePolicy,
        train_collector: Collector,
        test_collector: Collector,
        max_epoch: int,
        step_per_epoch: int,
        collect_per_step: int,
        repeat_per_collect: int,
        episode_per_test: Union[int, List[int]],
        batch_size: int,
        train_fn: Optional[Callable[[int], None]] = None,
        test_fn: Optional[Callable[[int], None]] = None,
        stop_fn: Optional[Callable[[float], bool]] = None,
        save_fn: Optional[Callable[[BasePolicy], None]] = None,
        log_fn: Optional[Callable[[dict], None]] = None,
        writer: Optional[SummaryWriter] = None,
        log_interval: int = 1,
        verbose: bool = True,
) -> Dict[str, Union[float, str]]:
"""A wrapper for on-policy trainer procedure.
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
class.
:param train_collector: the collector used for training.
:type train_collector: :class:`~tianshou.data.Collector`
:param test_collector: the collector used for testing.
:type test_collector: :class:`~tianshou.data.Collector`
:param int max_epoch: the maximum of epochs for training. The training
process might be finished before reaching the ``max_epoch``.
:param int step_per_epoch: the number of step for updating policy network
in one epoch.
:param int collect_per_step: the number of frames the collector would
collect before the network update. In other words, collect some frames
and do one policy network update.
:param int repeat_per_collect: the number of repeat time for policy
learning, for example, set it to 2 means the policy needs to learn each
given batch data twice.
:param episode_per_test: the number of episodes for one policy evaluation.
:type episode_per_test: int or list of ints
:param int batch_size: the batch size of sample data, which is going to
feed in the policy network.
:param function train_fn: a function receives the current number of epoch
index and performs some operations at the beginning of training in this
epoch.
:param function test_fn: a function receives the current number of epoch
index and performs some operations at the beginning of testing in this
epoch.
:param function save_fn: a function for saving policy when the undiscounted
average mean reward in evaluation phase gets better.
:param function stop_fn: a function receives the average undiscounted
returns of the testing result, return a boolean which indicates whether
reaching the goal.
:param function log_fn: a function receives env info for logging.
:param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
SummaryWriter.
:param int log_interval: the log interval of the writer.
:param bool verbose: whether to print the information.
:return: See :func:`~tianshou.trainer.gather_info`.
"""
    global_step = 0
    best_epoch, best_reward = -1, -1
    stat = {}
    start_time = time.time()
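    # testing the policy mid-epoch only makes sense when the train
    # collector is running the very policy that is being optimized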
    test_in_train = train_collector.policy == policy
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        if train_fn:
            train_fn(epoch)
        with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
                       **tqdm_config) as t:
            while t.n < t.total:
                result = train_collector.collect(n_episode=collect_per_step,
                                                 log_fn=log_fn)
                data = {}
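                # if the training reward already satisfies the stop
                # condition, confirm it on the test collector and, if it
                # holds there too, save the policy and return early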
                if test_in_train and stop_fn and stop_fn(result['rew']):
                    test_result = test_episode(
                        policy, test_collector, test_fn,
                        epoch, episode_per_test)
                    if stop_fn and stop_fn(test_result['rew']):
                        if save_fn:
                            save_fn(policy)
                        for k in result.keys():
                            data[k] = f'{result[k]:.2f}'
                        t.set_postfix(**data)
                        return gather_info(
                            start_time, train_collector, test_collector,
                            test_result['rew'])
                    else:
                        policy.train()
                        if train_fn:
                            train_fn(epoch)
                losses = policy.learn(
                    train_collector.sample(0), batch_size, repeat_per_collect)
                train_collector.reset_buffer()
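                # each loss recorded as a list holds one entry per gradient
                # step, so its length estimates how many updates learn() ran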
                step = 1
                for k in losses.keys():
                    if isinstance(losses[k], list):
                        step = max(step, len(losses[k]))
                global_step += step
                for k in result.keys():
                    data[k] = f'{result[k]:.2f}'
                    if writer and global_step % log_interval == 0:
                        writer.add_scalar(
                            k, result[k], global_step=global_step)
                for k in losses.keys():
                    if stat.get(k) is None:
                        stat[k] = MovAvg()
                    stat[k].add(losses[k])
                    data[k] = f'{stat[k].get():.6f}'
                    if writer and global_step % log_interval == 0:
                        writer.add_scalar(
                            k, stat[k].get(), global_step=global_step)
                t.update(step)
                t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # test
        result = test_episode(
            policy, test_collector, test_fn, epoch, episode_per_test)
        if best_epoch == -1 or best_reward < result['rew']:
            best_reward = result['rew']
            best_epoch = epoch
            if save_fn:
                save_fn(policy)
        if verbose:
            print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
                  f'best_reward: {best_reward:.6f} in #{best_epoch}')
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(
        start_time, train_collector, test_collector, best_reward)
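
For reference, a minimal usage sketch of onpolicy_trainer on CartPole (not part of the file above), assuming the 0.2-era tianshou APIs this file imports; the Net class and all hyperparameter values are illustrative choices, not project defaults:

import gym
import numpy as np
import torch
from tianshou.data import Collector, ReplayBuffer
from tianshou.policy import PGPolicy
from tianshou.trainer import onpolicy_trainer


class Net(torch.nn.Module):
    def __init__(self, state_shape, action_shape):
        super().__init__()
        self.model = torch.nn.Sequential(
            torch.nn.Linear(int(np.prod(state_shape)), 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, int(np.prod(action_shape))),
            # Categorical takes probabilities as its first positional
            # argument, so the network ends with a softmax
            torch.nn.Softmax(dim=-1),
        )

    def forward(self, obs, state=None, info={}):
        # tianshou policies expect (logits, state) back from the model
        probs = self.model(torch.as_tensor(obs, dtype=torch.float32))
        return probs, state


env = gym.make('CartPole-v0')
net = Net(env.observation_space.shape, (env.action_space.n,))
optim = torch.optim.Adam(net.parameters(), lr=1e-3)
policy = PGPolicy(net, optim, torch.distributions.Categorical,
                  discount_factor=0.99)

train_collector = Collector(policy, env, ReplayBuffer(size=20000))
test_collector = Collector(policy, gym.make('CartPole-v0'))

result = onpolicy_trainer(
    policy, train_collector, test_collector,
    max_epoch=10, step_per_epoch=1000, collect_per_step=8,
    repeat_per_collect=2, episode_per_test=10, batch_size=64,
    stop_fn=lambda rew: rew >= env.spec.reward_threshold)
print(result)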