Tianshou/tianshou/trainer/offpolicy.py
2020-04-05 18:34:45 +08:00

import time
import tqdm

from tianshou.utils import tqdm_config, MovAvg
from tianshou.trainer import test_episode, gather_info


def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                      step_per_epoch, collect_per_step, episode_per_test,
                      batch_size, train_fn=None, test_fn=None, stop_fn=None,
                      writer=None, log_interval=1, verbose=True, task='',
                      **kwargs):
"""A wrapper for off-policy trainer procedure.
Parameters
* **policy** an instance of the :class:`~tianshou.policy.BasePolicy`\
class.
* **train_collector** the collector used for training.
* **test_collector** the collector used for testing.
* **max_epoch** the maximum of epochs for training. The training \
process might be finished before reaching the ``max_epoch``.
* **step_per_epoch** the number of step for updating policy network \
in one epoch.
* **collect_per_step** the number of frames the collector would \
collect before the network update. In other words, collect some \
frames and do one policy network update.
* **episode_per_test** the number of episodes for one policy \
evaluation.
* **batch_size** the batch size of sample data, which is going to \
feed in the policy network.
* **train_fn** a function receives the current number of epoch index\
and performs some operations at the beginning of training in this \
epoch.
* **test_fn** a function receives the current number of epoch index \
and performs some operations at the beginning of testing in this \
epoch.
* **stop_fn** a function receives the average undiscounted returns \
of the testing result, return a boolean which indicates whether \
reaching the goal.
* **writer** a SummaryWriter provided from TensorBoard.
* **log_interval** an int indicating the log interval of the writer.
* **verbose** a boolean indicating whether to print the information.
:return: See :func:`~tianshou.trainer.gather_info`.
"""
    global_step = 0
    best_epoch, best_reward = -1, -1
    stat = {}
    start_time = time.time()
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        if train_fn:
            train_fn(epoch)
        with tqdm.tqdm(
                total=step_per_epoch, desc=f'Epoch #{epoch}',
                **tqdm_config) as t:
            while t.n < t.total:
                result = train_collector.collect(n_step=collect_per_step)
                data = {}
                # if the training reward already satisfies stop_fn, run a
                # proper evaluation; stop only when the test reward also
                # satisfies stop_fn, otherwise switch back to training mode
                if stop_fn and stop_fn(result['rew']):
                    test_result = test_episode(
                        policy, test_collector, test_fn,
                        epoch, episode_per_test)
                    if stop_fn and stop_fn(test_result['rew']):
                        for k in result.keys():
                            data[k] = f'{result[k]:.2f}'
                        t.set_postfix(**data)
                        return gather_info(
                            start_time, train_collector, test_collector,
                            test_result['rew'])
                    else:
                        policy.train()
                        if train_fn:
                            train_fn(epoch)
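                # one gradient step per ``collect_per_step`` frames just
                # collected, capped by the remaining budget of this epoch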
                for i in range(min(
                        result['n/st'] // collect_per_step, t.total - t.n)):
                    global_step += 1
                    losses = policy.learn(train_collector.sample(batch_size))
                    for k in result.keys():
                        data[k] = f'{result[k]:.2f}'
                        if writer and global_step % log_interval == 0:
                            writer.add_scalar(
                                k + '_' + task if task else k,
                                result[k], global_step=global_step)
                    for k in losses.keys():
                        if stat.get(k) is None:
                            stat[k] = MovAvg()
                        stat[k].add(losses[k])
                        data[k] = f'{stat[k].get():.6f}'
                        if writer and global_step % log_interval == 0:
                            writer.add_scalar(
                                k + '_' + task if task else k,
                                stat[k].get(), global_step=global_step)
                    t.update(1)
                    t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # test
        result = test_episode(
            policy, test_collector, test_fn, epoch, episode_per_test)
        if best_epoch == -1 or best_reward < result['rew']:
            best_reward = result['rew']
            best_epoch = epoch
        if verbose:
            print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
                  f'best_reward: {best_reward:.6f} in #{best_epoch}')
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(
        start_time, train_collector, test_collector, best_reward)
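

# A minimal usage sketch, assuming the policy and the two collectors (with a
# replay buffer behind the train collector) are built elsewhere with the rest
# of the Tianshou API.  ``build_dqn_experiment`` is a hypothetical placeholder
# for that setup code, and the hyper-parameters below are illustrative only;
# the ``offpolicy_trainer`` call itself follows the signature defined above.
if __name__ == '__main__':
    from torch.utils.tensorboard import SummaryWriter

    # hypothetical helper standing in for policy/env/collector construction
    policy, train_collector, test_collector = build_dqn_experiment()
    result = offpolicy_trainer(
        policy, train_collector, test_collector,
        max_epoch=10, step_per_epoch=1000, collect_per_step=10,
        episode_per_test=100, batch_size=64,
        stop_fn=lambda mean_reward: mean_reward >= 195,  # e.g. CartPole goal
        writer=SummaryWriter('log/offpolicy'))
    print(result)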