| 
									
										
										
										
											2020-03-19 17:23:46 +08:00
										 |  |  | import time | 
					
						
							|  |  |  | import tqdm | 
					
						
							| 
									
										
										
										
											2020-05-12 11:31:47 +08:00
										 |  |  | from torch.utils.tensorboard import SummaryWriter | 
					
						
							|  |  |  | from typing import Dict, List, Union, Callable, Optional | 
					
						
							| 
									
										
										
										
											2020-03-19 17:23:46 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-05-12 11:31:47 +08:00
										 |  |  | from tianshou.data import Collector | 
					
						
							|  |  |  | from tianshou.policy import BasePolicy | 
					
						
							| 
									
										
										
										
											2020-03-19 17:23:46 +08:00
										 |  |  | from tianshou.utils import tqdm_config, MovAvg | 
					
						
							| 
									
										
										
										
											2020-03-20 19:52:29 +08:00
										 |  |  | from tianshou.trainer import test_episode, gather_info | 
					
						
							| 
									
										
										
										
											2020-03-19 17:23:46 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-05-12 11:31:47 +08:00
										 |  |  | def onpolicy_trainer( | 
					
						
							|  |  |  |         policy: BasePolicy, | 
					
						
							|  |  |  |         train_collector: Collector, | 
					
						
							|  |  |  |         test_collector: Collector, | 
					
						
							|  |  |  |         max_epoch: int, | 
					
						
							|  |  |  |         step_per_epoch: int, | 
					
						
							|  |  |  |         collect_per_step: int, | 
					
						
							|  |  |  |         repeat_per_collect: int, | 
					
						
							|  |  |  |         episode_per_test: Union[int, List[int]], | 
					
						
							|  |  |  |         batch_size: int, | 
					
						
							|  |  |  |         train_fn: Optional[Callable[[int], None]] = None, | 
					
						
							|  |  |  |         test_fn: Optional[Callable[[int], None]] = None, | 
					
						
							|  |  |  |         stop_fn: Optional[Callable[[float], bool]] = None, | 
					
						
							|  |  |  |         save_fn: Optional[Callable[[BasePolicy], None]] = None, | 
					
						
							|  |  |  |         log_fn: Optional[Callable[[dict], None]] = None, | 
					
						
							|  |  |  |         writer: Optional[SummaryWriter] = None, | 
					
						
							| 
									
										
										
										
											2020-05-16 20:08:32 +08:00
										 |  |  |         log_interval: int = 1, | 
					
						
							|  |  |  |         verbose: bool = True, | 
					
						
							| 
									
										
										
										
											2020-05-12 11:31:47 +08:00
										 |  |  |         **kwargs | 
					
						
							|  |  |  | ) -> Dict[str, Union[float, str]]: | 
					
						
							| 
									
										
										
										
											2020-04-05 18:34:45 +08:00
										 |  |  |     """A wrapper for on-policy trainer procedure.
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-06 19:36:59 +08:00
										 |  |  |     :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` | 
					
						
							|  |  |  |         class. | 
					
						
							|  |  |  |     :param train_collector: the collector used for training. | 
					
						
							|  |  |  |     :type train_collector: :class:`~tianshou.data.Collector` | 
					
						
							|  |  |  |     :param test_collector: the collector used for testing. | 
					
						
							|  |  |  |     :type test_collector: :class:`~tianshou.data.Collector` | 
					
						
							|  |  |  |     :param int max_epoch: the maximum of epochs for training. The training | 
					
						
							|  |  |  |         process might be finished before reaching the ``max_epoch``. | 
					
						
							|  |  |  |     :param int step_per_epoch: the number of step for updating policy network | 
					
						
							|  |  |  |         in one epoch. | 
					
						
							|  |  |  |     :param int collect_per_step: the number of frames the collector would | 
					
						
							|  |  |  |         collect before the network update. In other words, collect some frames | 
					
						
							|  |  |  |         and do one policy network update. | 
					
						
							|  |  |  |     :param int repeat_per_collect: the number of repeat time for policy | 
					
						
							|  |  |  |         learning, for example, set it to 2 means the policy needs to learn each | 
					
						
							|  |  |  |         given batch data twice. | 
					
						
							|  |  |  |     :param episode_per_test: the number of episodes for one policy evaluation. | 
					
						
							|  |  |  |     :type episode_per_test: int or list of ints | 
					
						
							|  |  |  |     :param int batch_size: the batch size of sample data, which is going to | 
					
						
							|  |  |  |         feed in the policy network. | 
					
						
							|  |  |  |     :param function train_fn: a function receives the current number of epoch | 
					
						
							|  |  |  |         index and performs some operations at the beginning of training in this | 
					
						
							|  |  |  |         epoch. | 
					
						
							|  |  |  |     :param function test_fn: a function receives the current number of epoch | 
					
						
							|  |  |  |         index and performs some operations at the beginning of testing in this | 
					
						
							|  |  |  |         epoch. | 
					
						
							| 
									
										
										
										
											2020-04-11 16:54:27 +08:00
										 |  |  |     :param function save_fn: a function for saving policy when the undiscounted | 
					
						
							|  |  |  |         average mean reward in evaluation phase gets better. | 
					
						
							| 
									
										
										
										
											2020-04-06 19:36:59 +08:00
										 |  |  |     :param function stop_fn: a function receives the average undiscounted | 
					
						
							|  |  |  |         returns of the testing result, return a boolean which indicates whether | 
					
						
							|  |  |  |         reaching the goal. | 
					
						
							| 
									
										
										
										
											2020-04-10 18:02:05 +08:00
										 |  |  |     :param function log_fn: a function receives env info for logging. | 
					
						
							| 
									
										
										
										
											2020-04-06 19:36:59 +08:00
										 |  |  |     :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard | 
					
						
							|  |  |  |         SummaryWriter. | 
					
						
							|  |  |  |     :param int log_interval: the log interval of the writer. | 
					
						
							|  |  |  |     :param bool verbose: whether to print the information. | 
					
						
							| 
									
										
										
										
											2020-04-05 18:34:45 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     :return: See :func:`~tianshou.trainer.gather_info`. | 
					
						
							|  |  |  |     """
 | 
					
						
							| 
									
										
										
										
											2020-03-19 17:23:46 +08:00
										 |  |  |     global_step = 0 | 
					
						
							|  |  |  |     best_epoch, best_reward = -1, -1 | 
					
						
							|  |  |  |     stat = {} | 
					
						
							|  |  |  |     start_time = time.time() | 
					
						
							| 
									
										
										
										
											2020-04-13 19:37:27 +08:00
										 |  |  |     test_in_train = train_collector.policy == policy | 
					
						
							| 
									
										
										
										
											2020-03-19 17:23:46 +08:00
										 |  |  |     for epoch in range(1, 1 + max_epoch): | 
					
						
							|  |  |  |         # train | 
					
						
							|  |  |  |         policy.train() | 
					
						
							|  |  |  |         if train_fn: | 
					
						
							|  |  |  |             train_fn(epoch) | 
					
						
							| 
									
										
										
										
											2020-04-08 21:13:15 +08:00
										 |  |  |         with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}', | 
					
						
							|  |  |  |                        **tqdm_config) as t: | 
					
						
							| 
									
										
										
										
											2020-03-19 17:23:46 +08:00
										 |  |  |             while t.n < t.total: | 
					
						
							| 
									
										
										
										
											2020-04-10 18:02:05 +08:00
										 |  |  |                 result = train_collector.collect(n_episode=collect_per_step, | 
					
						
							|  |  |  |                                                  log_fn=log_fn) | 
					
						
							| 
									
										
										
										
											2020-03-20 19:52:29 +08:00
										 |  |  |                 data = {} | 
					
						
							| 
									
										
										
										
											2020-04-13 19:37:27 +08:00
										 |  |  |                 if test_in_train and stop_fn and stop_fn(result['rew']): | 
					
						
							| 
									
										
										
										
											2020-03-20 19:52:29 +08:00
										 |  |  |                     test_result = test_episode( | 
					
						
							|  |  |  |                         policy, test_collector, test_fn, | 
					
						
							|  |  |  |                         epoch, episode_per_test) | 
					
						
							|  |  |  |                     if stop_fn and stop_fn(test_result['rew']): | 
					
						
							| 
									
										
										
										
											2020-04-11 16:54:27 +08:00
										 |  |  |                         if save_fn: | 
					
						
							|  |  |  |                             save_fn(policy) | 
					
						
							| 
									
										
										
										
											2020-03-20 19:52:29 +08:00
										 |  |  |                         for k in result.keys(): | 
					
						
							|  |  |  |                             data[k] = f'{result[k]:.2f}' | 
					
						
							|  |  |  |                         t.set_postfix(**data) | 
					
						
							|  |  |  |                         return gather_info( | 
					
						
							|  |  |  |                             start_time, train_collector, test_collector, | 
					
						
							|  |  |  |                             test_result['rew']) | 
					
						
							|  |  |  |                     else: | 
					
						
							|  |  |  |                         policy.train() | 
					
						
							|  |  |  |                         if train_fn: | 
					
						
							|  |  |  |                             train_fn(epoch) | 
					
						
							|  |  |  |                 losses = policy.learn( | 
					
						
							|  |  |  |                     train_collector.sample(0), batch_size, repeat_per_collect) | 
					
						
							| 
									
										
										
										
											2020-03-19 17:23:46 +08:00
										 |  |  |                 train_collector.reset_buffer() | 
					
						
							|  |  |  |                 step = 1 | 
					
						
							|  |  |  |                 for k in losses.keys(): | 
					
						
							|  |  |  |                     if isinstance(losses[k], list): | 
					
						
							|  |  |  |                         step = max(step, len(losses[k])) | 
					
						
							|  |  |  |                 global_step += step | 
					
						
							|  |  |  |                 for k in result.keys(): | 
					
						
							|  |  |  |                     data[k] = f'{result[k]:.2f}' | 
					
						
							| 
									
										
										
										
											2020-04-04 11:46:18 +08:00
										 |  |  |                     if writer and global_step % log_interval == 0: | 
					
						
							| 
									
										
										
										
											2020-03-19 17:23:46 +08:00
										 |  |  |                         writer.add_scalar( | 
					
						
							| 
									
										
										
										
											2020-05-05 13:39:51 +08:00
										 |  |  |                             k, result[k], global_step=global_step) | 
					
						
							| 
									
										
										
										
											2020-03-19 17:23:46 +08:00
										 |  |  |                 for k in losses.keys(): | 
					
						
							|  |  |  |                     if stat.get(k) is None: | 
					
						
							|  |  |  |                         stat[k] = MovAvg() | 
					
						
							|  |  |  |                     stat[k].add(losses[k]) | 
					
						
							|  |  |  |                     data[k] = f'{stat[k].get():.6f}' | 
					
						
							| 
									
										
										
										
											2020-04-04 11:46:18 +08:00
										 |  |  |                     if writer and global_step % log_interval == 0: | 
					
						
							| 
									
										
										
										
											2020-03-19 17:23:46 +08:00
										 |  |  |                         writer.add_scalar( | 
					
						
							| 
									
										
										
										
											2020-05-05 13:39:51 +08:00
										 |  |  |                             k, stat[k].get(), global_step=global_step) | 
					
						
							| 
									
										
										
										
											2020-03-19 17:23:46 +08:00
										 |  |  |                 t.update(step) | 
					
						
							|  |  |  |                 t.set_postfix(**data) | 
					
						
							|  |  |  |             if t.n <= t.total: | 
					
						
							|  |  |  |                 t.update() | 
					
						
							| 
									
										
										
										
											2020-03-20 19:52:29 +08:00
										 |  |  |         # test | 
					
						
							|  |  |  |         result = test_episode( | 
					
						
							|  |  |  |             policy, test_collector, test_fn, epoch, episode_per_test) | 
					
						
							| 
									
										
										
										
											2020-03-19 17:23:46 +08:00
										 |  |  |         if best_epoch == -1 or best_reward < result['rew']: | 
					
						
							|  |  |  |             best_reward = result['rew'] | 
					
						
							|  |  |  |             best_epoch = epoch | 
					
						
							| 
									
										
										
										
											2020-04-11 16:54:27 +08:00
										 |  |  |             if save_fn: | 
					
						
							|  |  |  |                 save_fn(policy) | 
					
						
							| 
									
										
										
										
											2020-03-19 17:23:46 +08:00
										 |  |  |         if verbose: | 
					
						
							|  |  |  |             print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, ' | 
					
						
							|  |  |  |                   f'best_reward: {best_reward:.6f} in #{best_epoch}') | 
					
						
							| 
									
										
										
										
											2020-03-20 19:52:29 +08:00
										 |  |  |         if stop_fn and stop_fn(best_reward): | 
					
						
							| 
									
										
										
										
											2020-03-19 17:23:46 +08:00
										 |  |  |             break | 
					
						
							| 
									
										
										
										
											2020-03-20 19:52:29 +08:00
										 |  |  |     return gather_info( | 
					
						
							|  |  |  |         start_time, train_collector, test_collector, best_reward) |