import time
import tqdm
from torch.utils.tensorboard import SummaryWriter
from typing import Dict, List, Union, Callable, Optional
from tianshou.data import Collector
from tianshou.policy import BasePolicy
from tianshou.utils import tqdm_config, MovAvg
from tianshou.trainer import test_episode, gather_info


def onpolicy_trainer(
    policy: BasePolicy,
    train_collector: Collector,
    test_collector: Collector,
    max_epoch: int,
    step_per_epoch: int,
    collect_per_step: int,
    repeat_per_collect: int,
    episode_per_test: Union[int, List[int]],
    batch_size: int,
    train_fn: Optional[Callable[[int, int], None]] = None,
    test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
    stop_fn: Optional[Callable[[float], bool]] = None,
    save_fn: Optional[Callable[[BasePolicy], None]] = None,
    writer: Optional[SummaryWriter] = None,
    log_interval: int = 1,
    verbose: bool = True,
    test_in_train: bool = True,
) -> Dict[str, Union[float, str]]:
    """A wrapper for on-policy trainer procedure.

    The "step" in trainer means a policy network update.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
        class.
    :param train_collector: the collector used for training.
    :type train_collector: :class:`~tianshou.data.Collector`
    :param test_collector: the collector used for testing.
    :type test_collector: :class:`~tianshou.data.Collector`
    :param int max_epoch: the maximum number of epochs for training. The
        training process might finish before reaching ``max_epoch``.
    :param int step_per_epoch: the number of steps for updating the policy
        network in one epoch.
    :param int collect_per_step: the number of episodes the collector would
        collect before the network update. In other words, collect some
        episodes and do one policy network update.
    :param int repeat_per_collect: the number of times the policy learns from
        each collected batch; for example, setting it to 2 means the policy
        learns from each given batch of data twice.
    :param episode_per_test: the number of episodes for one policy evaluation.
    :type episode_per_test: int or list of ints
    :param int batch_size: the batch size of the sample data that is fed into
        the policy network.
    :param function train_fn: a function that receives the current epoch
        number and step index, and performs some operations at the beginning
        of training in this epoch.
    :param function test_fn: a function that receives the current epoch
        number and step index, and performs some operations at the beginning
        of testing in this epoch.
    :param function save_fn: a function for saving the policy whenever the
        undiscounted average reward in the evaluation phase gets better.
    :param function stop_fn: a function that receives the average undiscounted
        return of the testing result and returns a boolean indicating whether
        the goal has been reached.
    :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
        SummaryWriter.
    :param int log_interval: the log interval of the writer.
    :param bool verbose: whether to print the information.
    :param bool test_in_train: whether to test in the training phase.

    :return: See :func:`~tianshou.trainer.gather_info`.
"""
    global_step = 0
    best_epoch, best_reward = -1, -1.0
    stat: Dict[str, MovAvg] = {}
    start_time = time.time()
    train_collector.reset_stat()
    test_collector.reset_stat()
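    # In-training testing only makes sense when the training collector runs
    # the very policy that is being optimized, so disable it otherwise.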
    test_in_train = test_in_train and train_collector.policy == policy
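    # Main loop: each epoch alternates between collecting episodes, updating
    # the policy on the freshly collected data, and a final evaluation pass.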
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        with tqdm.tqdm(
            total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config
        ) as t:
            while t.n < t.total:
                if train_fn:
                    train_fn(epoch, global_step)
                result = train_collector.collect(n_episode=collect_per_step)
                data = {}
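                # If the collected training reward already satisfies stop_fn,
                # confirm it with a full evaluation; stop early only when the
                # test reward also passes, otherwise switch back to training.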
                if test_in_train and stop_fn and stop_fn(result["rew"]):
                    test_result = test_episode(
                        policy, test_collector, test_fn,
                        epoch, episode_per_test, writer, global_step)
                    if stop_fn(test_result["rew"]):
                        if save_fn:
                            save_fn(policy)
                        for k in result.keys():
                            data[k] = f"{result[k]:.2f}"
                        t.set_postfix(**data)
                        return gather_info(
                            start_time, train_collector, test_collector,
                            test_result["rew"])
                    else:
                        policy.train()
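                # On-policy update: a sample size of 0 feeds the whole buffer
                # to the policy; the buffer is cleared afterwards because
                # on-policy data cannot be reused.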
                losses = policy.update(
                    0, train_collector.buffer,
                    batch_size=batch_size, repeat=repeat_per_collect)
                train_collector.reset_buffer()
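                # Infer how many gradient steps the update performed from the
                # longest per-step loss list, and advance the step counter
                # accordingly.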
                step = 1
                for v in losses.values():
                    if isinstance(v, list):
                        step = max(step, len(v))
                global_step += step * collect_per_step
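                # Report the collection statistics and a moving average of
                # each loss, both on the progress bar and in TensorBoard.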
                for k in result.keys():
                    data[k] = f"{result[k]:.2f}"
                    if writer and global_step % log_interval == 0:
                        writer.add_scalar(
                            "train/" + k, result[k], global_step=global_step)
                for k in losses.keys():
                    if stat.get(k) is None:
                        stat[k] = MovAvg()
                    stat[k].add(losses[k])
                    data[k] = f"{stat[k].get():.6f}"
                    if writer and global_step % log_interval == 0:
                        writer.add_scalar(
                            k, stat[k].get(), global_step=global_step)
                t.update(step)
                t.set_postfix(**data)
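            # Nudge the progress bar one extra tick if it has not yet passed
            # its total.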
            if t.n <= t.total:
                t.update()
        # test
        result = test_episode(policy, test_collector, test_fn, epoch,
                              episode_per_test, writer, global_step)
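        # Track the best evaluation reward seen so far and save the policy
        # whenever it improves.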
        if best_epoch == -1 or best_reward < result["rew"]:
            best_reward = result["rew"]
            best_epoch = epoch
            if save_fn:
                save_fn(policy)
        if verbose:
            print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f}, "
                  f"best_reward: {best_reward:.6f} in #{best_epoch}")
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(
        start_time, train_collector, test_collector, best_reward)