Tianshou/tianshou/trainer/onpolicy.py

import time
import tqdm
from tianshou.utils import tqdm_config, MovAvg
from tianshou.trainer import test_episode, gather_info


def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                     step_per_epoch, collect_per_step, repeat_per_collect,
                     episode_per_test, batch_size,
                     train_fn=None, test_fn=None, stop_fn=None,
                     writer=None, log_interval=1, verbose=True, task='',
                     **kwargs):
"""A wrapper for on-policy trainer procedure.
Parameters
* **policy** an instance of the :class:`~tianshou.policy.BasePolicy`\
class.
* **train_collector** the collector used for training.
* **test_collector** the collector used for testing.
* **max_epoch** the maximum of epochs for training. The training \
process might be finished before reaching the ``max_epoch``.
* **step_per_epoch** the number of step for updating policy network \
in one epoch.
* **collect_per_step** the number of frames the collector would \
collect before the network update. In other words, collect some \
frames and do one policy network update.
* **repeat_per_collect** the number of repeat time for policy \
learning, for example, set it to 2 means the policy needs to learn\
each given batch data twice.
* **episode_per_test** the number of episodes for one policy \
evaluation.
* **batch_size** the batch size of sample data, which is going to \
feed in the policy network.
* **train_fn** a function receives the current number of epoch index\
and performs some operations at the beginning of training in this \
epoch.
* **test_fn** a function receives the current number of epoch index \
and performs some operations at the beginning of testing in this \
epoch.
* **stop_fn** a function receives the average undiscounted returns \
of the testing result, return a boolean which indicates whether \
reaching the goal.
* **writer** a SummaryWriter provided from TensorBoard.
* **log_interval** an int indicating the log interval of the writer.
* **verbose** a boolean indicating whether to print the information.
:return: See :func:`~tianshou.trainer.gather_info`.
"""
    global_step = 0
    best_epoch, best_reward = -1, -1
    stat = {}
    start_time = time.time()
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        if train_fn:
            train_fn(epoch)
        with tqdm.tqdm(
                total=step_per_epoch, desc=f'Epoch #{epoch}',
                **tqdm_config) as t:
            while t.n < t.total:
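                # collect ``collect_per_step`` whole episodes with the
                # current policy before each network update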
                result = train_collector.collect(n_episode=collect_per_step)
                data = {}
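                # if the training reward already satisfies stop_fn, run a
                # full evaluation and return early when the test reward
                # confirms it; otherwise switch back to training mode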
                if stop_fn and stop_fn(result['rew']):
                    test_result = test_episode(
                        policy, test_collector, test_fn,
                        epoch, episode_per_test)
                    if stop_fn and stop_fn(test_result['rew']):
                        for k in result.keys():
                            data[k] = f'{result[k]:.2f}'
                        t.set_postfix(**data)
                        return gather_info(
                            start_time, train_collector, test_collector,
                            test_result['rew'])
                    else:
                        policy.train()
                        if train_fn:
                            train_fn(epoch)
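                # on-policy update: learn from everything currently in the
                # buffer (``sample(0)`` draws the whole buffer), repeating it
                # ``repeat_per_collect`` times, then discard the collected data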
                losses = policy.learn(
                    train_collector.sample(0), batch_size, repeat_per_collect)
                train_collector.reset_buffer()
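                # infer how many gradient steps ``learn`` performed from the
                # length of the longest loss list it returned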
                step = 1
                for k in losses.keys():
                    if isinstance(losses[k], list):
                        step = max(step, len(losses[k]))
                global_step += step
                for k in result.keys():
                    data[k] = f'{result[k]:.2f}'
                    if writer and global_step % log_interval == 0:
                        writer.add_scalar(
                            k + '_' + task if task else k,
                            result[k], global_step=global_step)
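                # smooth each loss with a moving average for display and logging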
                for k in losses.keys():
                    if stat.get(k) is None:
                        stat[k] = MovAvg()
                    stat[k].add(losses[k])
                    data[k] = f'{stat[k].get():.6f}'
                    if writer and global_step % log_interval == 0:
                        writer.add_scalar(
                            k + '_' + task if task else k,
                            stat[k].get(), global_step=global_step)
                t.update(step)
                t.set_postfix(**data)
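            # if the bar has not gone past its total, tick it one final time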
            if t.n <= t.total:
                t.update()
        # test
        result = test_episode(
            policy, test_collector, test_fn, epoch, episode_per_test)
        if best_epoch == -1 or best_reward < result['rew']:
            best_reward = result['rew']
            best_epoch = epoch
        if verbose:
            print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
                  f'best_reward: {best_reward:.6f} in #{best_epoch}')
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(
        start_time, train_collector, test_collector, best_reward)
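

# A minimal usage sketch (illustrative only, not part of the original module),
# assuming ``policy``, ``train_collector``, ``test_collector``, ``stop_fn``
# and ``writer`` have been constructed elsewhere, e.g. as in the repository's
# test scripts; the numeric values are placeholders:
#
#     result = onpolicy_trainer(
#         policy, train_collector, test_collector,
#         max_epoch=100, step_per_epoch=1000, collect_per_step=10,
#         repeat_per_collect=2, episode_per_test=10, batch_size=64,
#         stop_fn=stop_fn, writer=writer)
#     print(result)  # summary dict returned by gather_info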