146 lines
		
	
	
		
			6.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			146 lines
		
	
	
		
			6.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import time
 | |
| import tqdm
 | |
| from torch.utils.tensorboard import SummaryWriter
 | |
| from typing import Dict, List, Union, Callable, Optional
 | |
| 
 | |
| from tianshou.data import Collector
 | |
| from tianshou.policy import BasePolicy
 | |
| from tianshou.utils import tqdm_config, MovAvg
 | |
| from tianshou.trainer import test_episode, gather_info
 | |
| 
 | |
| 
 | |
def onpolicy_trainer(
        policy: BasePolicy,
        train_collector: Collector,
        test_collector: Collector,
        max_epoch: int,
        step_per_epoch: int,
        collect_per_step: int,
        repeat_per_collect: int,
        episode_per_test: Union[int, List[int]],
        batch_size: int,
        train_fn: Optional[Callable[[int], None]] = None,
        test_fn: Optional[Callable[[int], None]] = None,
        stop_fn: Optional[Callable[[float], bool]] = None,
        save_fn: Optional[Callable[[BasePolicy], None]] = None,
        log_fn: Optional[Callable[[dict], None]] = None,
        writer: Optional[SummaryWriter] = None,
        log_interval: int = 1,
        verbose: bool = True,
        **kwargs
) -> Dict[str, Union[float, str]]:
    """A wrapper for on-policy trainer procedure.

    Each epoch repeatedly collects ``collect_per_step`` episodes with
    ``train_collector``, calls ``policy.learn`` on the collected data and
    then clears the training buffer, until the epoch's progress counter
    reaches ``step_per_epoch``. The policy is then evaluated with
    ``test_collector``.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
        class.
    :param train_collector: the collector used for training.
    :type train_collector: :class:`~tianshou.data.Collector`
    :param test_collector: the collector used for testing.
    :type test_collector: :class:`~tianshou.data.Collector`
    :param int max_epoch: the maximum of epochs for training. The training
        process might be finished before reaching the ``max_epoch``.
    :param int step_per_epoch: the size of the per-epoch progress counter.
        After each ``policy.learn`` call the counter advances by the length
        of the longest list-valued entry in the returned losses dict (or by
        1 if there is none).
    :param int collect_per_step: the number of episodes the collector would
        collect before the network update. In other words, collect some
        episodes and do one policy network update.
    :param int repeat_per_collect: the number of repeat time for policy
        learning, for example, set it to 2 means the policy needs to learn each
        given batch data twice.
    :param episode_per_test: the number of episodes for one policy evaluation.
    :type episode_per_test: int or list of ints
    :param int batch_size: the batch size of sample data, which is going to
        feed in the policy network.
    :param function train_fn: a function receives the current number of epoch
        index and performs some operations at the beginning of training in this
        epoch. It is called again after a failed early-stop evaluation.
    :param function test_fn: a function receives the current number of epoch
        index and performs some operations at the beginning of testing in this
        epoch. It is forwarded to ``test_episode``.
    :param function save_fn: a function for saving policy. It is called when
        the undiscounted average reward in the evaluation phase gets better. It
        is also called right before an early-stop return.
    :param function stop_fn: a function receives the average undiscounted
        returns of the testing result, return a boolean which indicates whether
        reaching the goal. When ``train_collector`` drives ``policy`` itself,
        it is also applied to the training reward to decide whether an
        early-stop evaluation should be run.
    :param function log_fn: a function receives env info for logging; it is
        forwarded to ``train_collector.collect``.
    :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
        SummaryWriter.
    :param int log_interval: the log interval of the writer. Scalars are
        written only when the cumulative step counter is an exact multiple of
        ``log_interval``. Because the counter may advance by more than 1 per
        update, some intervals can be skipped.
    :param bool verbose: whether to print the per-epoch test summary.
    :param kwargs: accepted for interface compatibility and ignored.

    :return: See :func:`~tianshou.trainer.gather_info`.
    """
    global_step = 0
    best_epoch, best_reward = -1, -1
    stat = {}  # loss name -> MovAvg accumulated over the whole run
    start_time = time.time()
    # In-training early stopping is only attempted when the training
    # collector drives this very policy.
    test_in_train = train_collector.policy == policy
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        if train_fn:
            train_fn(epoch)
        with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
                       **tqdm_config) as t:
            while t.n < t.total:
                result = train_collector.collect(n_episode=collect_per_step,
                                                 log_fn=log_fn)
                data = {}
                if test_in_train and stop_fn and stop_fn(result['rew']):
                    # The training reward already satisfies stop_fn; confirm
                    # with a real evaluation before stopping early.
                    test_result = test_episode(
                        policy, test_collector, test_fn,
                        epoch, episode_per_test)
                    if stop_fn(test_result['rew']):
                        if save_fn:
                            save_fn(policy)
                        for k, v in result.items():
                            data[k] = f'{v:.2f}'
                        t.set_postfix(**data)
                        return gather_info(
                            start_time, train_collector, test_collector,
                            test_result['rew'])
                    # Evaluation did not confirm the goal: re-enter training
                    # mode (and re-run train_fn) before continuing.
                    policy.train()
                    if train_fn:
                        train_fn(epoch)
                # On-policy update: learn from what was just collected, then
                # empty the buffer so the next update only sees fresh data.
                # NOTE(review): sample(0) is assumed to return the whole
                # buffer — confirm against Collector.sample.
                losses = policy.learn(
                    train_collector.sample(0), batch_size, repeat_per_collect)
                train_collector.reset_buffer()
                # Progress advances by the longest list-valued loss (at
                # least 1), i.e. the number of entries policy.learn reported.
                step = 1
                for v in losses.values():
                    if isinstance(v, list):
                        step = max(step, len(v))
                global_step += step
                log_now = bool(writer) and global_step % log_interval == 0
                for k, v in result.items():
                    data[k] = f'{v:.2f}'
                    if log_now:
                        writer.add_scalar(k, v, global_step=global_step)
                for k, v in losses.items():
                    if k not in stat:
                        stat[k] = MovAvg()
                    stat[k].add(v)
                    data[k] = f'{stat[k].get():.6f}'
                    if log_now:
                        writer.add_scalar(
                            k, stat[k].get(), global_step=global_step)
                t.update(step)
                t.set_postfix(**data)
            # NOTE(review): after the loop t.n >= t.total, so this only fires
            # when the bar landed exactly on total and pushes it one past.
            # Kept as-is; intent is unclear (possibly a display refresh).
            if t.n <= t.total:
                t.update()
        # test
        result = test_episode(
            policy, test_collector, test_fn, epoch, episode_per_test)
        if best_epoch == -1 or best_reward < result['rew']:
            best_reward = result['rew']
            best_epoch = epoch
            if save_fn:
                save_fn(policy)
        if verbose:
            print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
                  f'best_reward: {best_reward:.6f} in #{best_epoch}')
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(
        start_time, train_collector, test_collector, best_reward)
 |