import time
import tqdm
from torch.utils.tensorboard import SummaryWriter
from typing import Dict, List, Union, Callable, Optional

from tianshou.data import Collector
from tianshou.policy import BasePolicy
from tianshou.utils import tqdm_config, MovAvg
from tianshou.trainer import test_episode, gather_info


def onpolicy_trainer(
    policy: BasePolicy,
    train_collector: Collector,
    test_collector: Collector,
    max_epoch: int,
    step_per_epoch: int,
    collect_per_step: int,
    repeat_per_collect: int,
    episode_per_test: Union[int, List[int]],
    batch_size: int,
    train_fn: Optional[Callable[[int, int], None]] = None,
    test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
    stop_fn: Optional[Callable[[float], bool]] = None,
    save_fn: Optional[Callable[[BasePolicy], None]] = None,
    writer: Optional[SummaryWriter] = None,
    log_interval: int = 1,
    verbose: bool = True,
    test_in_train: bool = True,
) -> Dict[str, Union[float, str]]:
    """A wrapper for the on-policy trainer procedure.

    The "step" in the trainer means a policy network update.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
        class.
    :param train_collector: the collector used for training.
    :type train_collector: :class:`~tianshou.data.Collector`
    :param test_collector: the collector used for testing.
    :type test_collector: :class:`~tianshou.data.Collector`
    :param int max_epoch: the maximum number of epochs for training. The
        training process might finish before reaching ``max_epoch``.
    :param int step_per_epoch: the number of policy network update steps in
        one epoch.
    :param int collect_per_step: the number of episodes the collector
        collects before each network update; in other words, collect some
        episodes, then do one policy network update.
    :param int repeat_per_collect: the number of times the policy learns from
        each collected batch; for example, setting it to 2 means the policy
        learns each given batch of data twice.
    :param episode_per_test: the number of episodes for one policy evaluation.
    :type episode_per_test: int or list of ints
    :param int batch_size: the batch size of the sample data that is fed into
        the policy network.
    :param function train_fn: a function that receives the current epoch
        number and step index, and performs some operations at the beginning
        of training in this epoch (see the sketch in the example below).
    :param function test_fn: a function that receives the current epoch
        number and step index, and performs some operations at the beginning
        of testing in this epoch.
    :param function save_fn: a function for saving the policy when the
        undiscounted average mean reward in the evaluation phase gets better.
    :param function stop_fn: a function that receives the average
        undiscounted returns of the testing result and returns a boolean
        indicating whether the goal has been reached.
    :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
        SummaryWriter.
    :param int log_interval: the log interval of the writer.
    :param bool verbose: whether to print information such as the test reward.
    :param bool test_in_train: whether to test in the training phase.

    :return: See :func:`~tianshou.trainer.gather_info`.
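
    Example of usage (a minimal sketch rather than a verbatim snippet: the
    policy, the collectors, and ``lr_scheduler`` are assumed to be built
    elsewhere, e.g. for a Gym "CartPole-v0" task)::

        result = onpolicy_trainer(
            policy, train_collector, test_collector, max_epoch=10,
            step_per_epoch=1000, collect_per_step=1, repeat_per_collect=2,
            episode_per_test=100, batch_size=64,
            # e.g. step an (assumed) torch lr_scheduler before each collect
            train_fn=lambda epoch, step: lr_scheduler.step(),
            # stop once the mean test reward reaches CartPole-v0's threshold
            stop_fn=lambda mean_reward: mean_reward >= 195,
        )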
							|  |  |  |     """
 | 
					
						
    global_step = 0
    best_epoch, best_reward = -1, -1.0
    stat: Dict[str, MovAvg] = {}
    start_time = time.time()
    train_collector.reset_stat()
    test_collector.reset_stat()
    # testing inside the train loop is only meaningful when the train
    # collector runs the very policy object that is being optimized
    test_in_train = test_in_train and train_collector.policy == policy
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        with tqdm.tqdm(
            total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config
        ) as t:
            while t.n < t.total:
                if train_fn:
                    train_fn(epoch, global_step)
                result = train_collector.collect(n_episode=collect_per_step)
                data = {}
                # if the training reward already satisfies stop_fn, verify it
                # on the test collector and exit early when it still holds
                if test_in_train and stop_fn and stop_fn(result["rew"]):
                    test_result = test_episode(
                        policy, test_collector, test_fn,
                        epoch, episode_per_test, writer, global_step)
                    if stop_fn(test_result["rew"]):
                        if save_fn:
                            save_fn(policy)
                        for k in result.keys():
                            data[k] = f"{result[k]:.2f}"
                        t.set_postfix(**data)
                        return gather_info(
                            start_time, train_collector, test_collector,
                            test_result["rew"])
                    else:
                        # goal not reached yet: switch back to training mode
                        policy.train()
                # sample_size=0 means learning from the whole buffer
                losses = policy.update(
                    0, train_collector.buffer,
                    batch_size=batch_size, repeat=repeat_per_collect)
                train_collector.reset_buffer()
                # the number of gradient steps equals the longest loss list
                # returned by the policy update
                step = 1
                for v in losses.values():
                    if isinstance(v, list):
                        step = max(step, len(v))
                global_step += step * collect_per_step
                for k in result.keys():
                    data[k] = f"{result[k]:.2f}"
                    if writer and global_step % log_interval == 0:
                        writer.add_scalar(
                            "train/" + k, result[k], global_step=global_step)
                # smooth the losses with a moving average before logging
                for k in losses.keys():
                    if stat.get(k) is None:
                        stat[k] = MovAvg()
                    stat[k].add(losses[k])
                    data[k] = f"{stat[k].get():.6f}"
                    if writer and global_step % log_interval == 0:
                        writer.add_scalar(
                            k, stat[k].get(), global_step=global_step)
                t.update(step)
                t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # test
        result = test_episode(policy, test_collector, test_fn, epoch,
                              episode_per_test, writer, global_step)
        # track the best evaluation reward seen so far and save the policy
        # whenever it improves
        if best_epoch == -1 or best_reward < result["rew"]:
            best_reward = result["rew"]
            best_epoch = epoch
            if save_fn:
                save_fn(policy)
        if verbose:
            print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f}, "
                  f"best_reward: {best_reward:.6f} in #{best_epoch}")
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(
        start_time, train_collector, test_collector, best_reward)
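

# How the pieces above are typically wired together (an illustrative sketch,
# kept as comments on purpose: the PGPolicy/Collector constructor arguments
# below are assumptions that may differ across tianshou versions, and
# build_net_and_optim is a hypothetical helper):
#
#     import gym
#     import torch
#     from tianshou.data import Collector, ReplayBuffer
#     from tianshou.policy import PGPolicy
#
#     env = gym.make("CartPole-v0")
#     net, optim = build_net_and_optim(env)  # hypothetical helper
#     policy = PGPolicy(net, optim, dist_fn=torch.distributions.Categorical)
#     train_collector = Collector(policy, env, ReplayBuffer(size=20000))
#     test_collector = Collector(policy, env)
#     result = onpolicy_trainer(policy, train_collector, test_collector,
#                               max_epoch=10, step_per_epoch=1000,
#                               collect_per_step=1, repeat_per_collect=2,
#                               episode_per_test=100, batch_size=64)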