import time
import tqdm

from tianshou.utils import tqdm_config, MovAvg
from tianshou.trainer import test_episode, gather_info


def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                      step_per_epoch, collect_per_step, episode_per_test,
                      batch_size, train_fn=None, test_fn=None, stop_fn=None,
                      writer=None, log_interval=1, verbose=True, task='',
                      **kwargs):
"""A wrapper for off-policy trainer procedure.
|
||
|
||
Parameters
|
||
* **policy** – an instance of the :class:`~tianshou.policy.BasePolicy`\
|
||
class.
|
||
* **train_collector** – the collector used for training.
|
||
* **test_collector** – the collector used for testing.
|
||
* **max_epoch** – the maximum of epochs for training. The training \
|
||
process might be finished before reaching the ``max_epoch``.
|
||
* **step_per_epoch** – the number of step for updating policy network \
|
||
in one epoch.
|
||
* **collect_per_step** – the number of frames the collector would \
|
||
collect before the network update. In other words, collect some \
|
||
frames and do one policy network update.
|
||
* **episode_per_test** – the number of episodes for one policy \
|
||
evaluation.
|
||
* **batch_size** – the batch size of sample data, which is going to \
|
||
feed in the policy network.
|
||
* **train_fn** – a function receives the current number of epoch index\
|
||
and performs some operations at the beginning of training in this \
|
||
epoch.
|
||
* **test_fn** – a function receives the current number of epoch index \
|
||
and performs some operations at the beginning of testing in this \
|
||
epoch.
|
||
* **stop_fn** – a function receives the average undiscounted returns \
|
||
of the testing result, return a boolean which indicates whether \
|
||
reaching the goal.
|
||
* **writer** – a SummaryWriter provided from TensorBoard.
|
||
* **log_interval** – an int indicating the log interval of the writer.
|
||
* **verbose** – a boolean indicating whether to print the information.
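
        A minimal example of a ``stop_fn`` (illustrative, not part of the
        original API; 195 is CartPole-v0's reward threshold)::

            def stop_fn(mean_reward):
                # stop training once the mean test return reaches the goal
                return mean_reward >= 195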

    :return: See :func:`~tianshou.trainer.gather_info`.
    """
    global_step = 0
    best_epoch, best_reward = -1, -1
    stat = {}
    start_time = time.time()
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        if train_fn:
            train_fn(epoch)
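        # the progress bar counts policy-network updates, not environment
        # steps: one epoch finishes after `step_per_epoch` updates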
        with tqdm.tqdm(
                total=step_per_epoch, desc=f'Epoch #{epoch}',
                **tqdm_config) as t:
            while t.n < t.total:
                result = train_collector.collect(n_step=collect_per_step)
                data = {}
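                # if the training returns already satisfy stop_fn, confirm
                # with a full evaluation and exit early on success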
                if stop_fn and stop_fn(result['rew']):
                    test_result = test_episode(
                        policy, test_collector, test_fn,
                        epoch, episode_per_test)
                    if stop_fn and stop_fn(test_result['rew']):
                        for k in result.keys():
                            data[k] = f'{result[k]:.2f}'
                        t.set_postfix(**data)
                        return gather_info(
                            start_time, train_collector, test_collector,
                            test_result['rew'])
                    else:
                        policy.train()
                        if train_fn:
                            train_fn(epoch)
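                # one policy update per `collect_per_step` collected frames,
                # capped by the remaining step budget of this epoch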
                for i in range(min(
                        result['n/st'] // collect_per_step, t.total - t.n)):
                    global_step += 1
                    losses = policy.learn(train_collector.sample(batch_size))
                    for k in result.keys():
                        data[k] = f'{result[k]:.2f}'
                        if writer and global_step % log_interval == 0:
                            writer.add_scalar(
                                k + '_' + task if task else k,
                                result[k], global_step=global_step)
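                    # smooth each reported loss with a moving average before
                    # displaying and logging it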
                    for k in losses.keys():
                        if stat.get(k) is None:
                            stat[k] = MovAvg()
                        stat[k].add(losses[k])
                        data[k] = f'{stat[k].get():.6f}'
                        if writer and global_step % log_interval == 0:
                            writer.add_scalar(
                                k + '_' + task if task else k,
                                stat[k].get(), global_step=global_step)
                    t.update(1)
                    t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # test
        result = test_episode(
            policy, test_collector, test_fn, epoch, episode_per_test)
        if best_epoch == -1 or best_reward < result['rew']:
            best_reward = result['rew']
            best_epoch = epoch
        if verbose:
            print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
                  f'best_reward: {best_reward:.6f} in #{best_epoch}')
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(
        start_time, train_collector, test_collector, best_reward)
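
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): one way to wire the
# trainer together. ``make_policy`` and ``make_env`` are hypothetical
# helpers, and all hyper-parameter values below are illustrative
# placeholders, not recommendations from this file.
#
#     from torch.utils.tensorboard import SummaryWriter
#     from tianshou.data import Collector, ReplayBuffer
#
#     policy = make_policy()        # any BasePolicy subclass, e.g. DQNPolicy
#     train_collector = Collector(policy, make_env(), ReplayBuffer(20000))
#     test_collector = Collector(policy, make_env())
#     result = offpolicy_trainer(
#         policy, train_collector, test_collector, max_epoch=10,
#         step_per_epoch=1000, collect_per_step=10, episode_per_test=100,
#         batch_size=64, stop_fn=lambda rew: rew >= 195,
#         writer=SummaryWriter('log/dqn'))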