import time
import tqdm
from tianshou.utils import tqdm_config, MovAvg
from tianshou.trainer import test_episode, gather_info


def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                     step_per_epoch, collect_per_step, repeat_per_collect,
                     episode_per_test, batch_size,
                     train_fn=None, test_fn=None, stop_fn=None,
                     writer=None, verbose=True, task=''):
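    """A wrapper for the on-policy trainer procedure.

    Each training step collects ``collect_per_step`` episodes, updates the
    policy on the whole buffer for ``repeat_per_collect`` repeats, and each
    epoch ends with an evaluation over ``episode_per_test`` episodes.
    ``train_fn`` and ``test_fn`` are optional hooks called with the current
    epoch; ``stop_fn`` receives a mean reward and returns True to stop
    training early.  Statistics are returned via ``gather_info``.
    """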
    global_step = 0
    best_epoch, best_reward = -1, -1
    stat = {}
    start_time = time.time()
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        if train_fn:
            train_fn(epoch)
        with tqdm.tqdm(
                total=step_per_epoch, desc=f'Epoch #{epoch}',
                **tqdm_config) as t:
            while t.n < t.total:
                result = train_collector.collect(n_episode=collect_per_step)
                data = {}
                if stop_fn and stop_fn(result['rew']):
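                    # the collected training reward already satisfies
                    # stop_fn; run a full test to confirm before stopping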
                    test_result = test_episode(
                        policy, test_collector, test_fn,
                        epoch, episode_per_test)
                    if stop_fn and stop_fn(test_result['rew']):
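                        # confirmed: flush the final statistics to the
                        # progress bar and return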
                        for k in result.keys():
                            data[k] = f'{result[k]:.2f}'
                        t.set_postfix(**data)
                        return gather_info(
                            start_time, train_collector, test_collector,
                            test_result['rew'])
                    else:
                        policy.train()
                        if train_fn:
                            train_fn(epoch)
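                # update on everything collected so far; sample(0) is
                # assumed to mean "draw the whole buffer"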
                losses = policy.learn(
                    train_collector.sample(0), batch_size, repeat_per_collect)
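                # on-policy data can only be used once; drop the buffer
                # contents now that the update is done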
                train_collector.reset_buffer()
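                # infer the number of gradient steps performed: losses
                # recorded as lists hold one entry per gradient step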
                step = 1
                for k in losses.keys():
                    if isinstance(losses[k], list):
                        step = max(step, len(losses[k]))
                global_step += step
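                # push the latest collect stats and smoothed losses to the
                # progress bar and, if available, TensorBoard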
                for k in result.keys():
                    data[k] = f'{result[k]:.2f}'
                    if writer:
                        writer.add_scalar(
                            k + '_' + task, result[k], global_step=global_step)
                for k in losses.keys():
                    if stat.get(k) is None:
                        stat[k] = MovAvg()
                    stat[k].add(losses[k])
                    data[k] = f'{stat[k].get():.6f}'
                    if writer and global_step:
                        writer.add_scalar(
                            k + '_' + task, stat[k].get(),
                            global_step=global_step)
                t.update(step)
                t.set_postfix(**data)
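            # make sure the progress bar reaches its end for this epoch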
            if t.n <= t.total:
                t.update()
        # test
        result = test_episode(
            policy, test_collector, test_fn, epoch, episode_per_test)
        if best_epoch == -1 or best_reward < result['rew']:
            best_reward = result['rew']
            best_epoch = epoch
        if verbose:
            print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
                  f'best_reward: {best_reward:.6f} in #{best_epoch}')
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(
        start_time, train_collector, test_collector, best_reward)
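

# Example wiring (an illustrative sketch, not part of this module): the hook
# signatures follow how they are called above -- ``stop_fn`` receives a mean
# reward and returns True to stop, ``train_fn``/``test_fn`` receive the epoch
# number.  The threshold and hyper-parameters below are assumptions for a
# CartPole-style task; ``policy``, ``train_collector`` and ``test_collector``
# must be built beforehand:
#
#     def stop_fn(mean_reward):
#         return mean_reward >= 195  # CartPole-v0's solve threshold
#
#     result = onpolicy_trainer(
#         policy, train_collector, test_collector, max_epoch=100,
#         step_per_epoch=1000, collect_per_step=10, repeat_per_collect=2,
#         episode_per_test=100, batch_size=64, stop_fn=stop_fn)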