| 
									
										
										
										
											2020-03-20 19:52:29 +08:00
										 |  |  | import time | 
					
						
							| 
									
										
										
										
											2020-03-27 09:04:29 +08:00
										 |  |  | import numpy as np | 
					
						
							| 
									
										
										
										
											2020-05-12 11:31:47 +08:00
										 |  |  | from typing import Dict, List, Union, Callable | 
					
						
							| 
									
										
										
										
											2020-03-20 19:52:29 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-05-12 11:31:47 +08:00
										 |  |  | from tianshou.data import Collector | 
					
						
							|  |  |  | from tianshou.policy import BasePolicy | 
					
						
							| 
									
										
										
										
											2020-03-20 19:52:29 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-05-12 11:31:47 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | def test_episode( | 
					
						
							|  |  |  |         policy: BasePolicy, | 
					
						
							|  |  |  |         collector: Collector, | 
					
						
							|  |  |  |         test_fn: Callable[[int], None], | 
					
						
							|  |  |  |         epoch: int, | 
					
						
							|  |  |  |         n_episode: Union[int, List[int]]) -> Dict[str, float]: | 
					
						
							| 
									
										
										
										
											2020-04-04 21:02:06 +08:00
										 |  |  |     """A simple wrapper of testing policy in collector.""" | 
					
						
							| 
									
										
										
										
											2020-03-20 19:52:29 +08:00
										 |  |  |     collector.reset_env() | 
					
						
							|  |  |  |     collector.reset_buffer() | 
					
						
							|  |  |  |     policy.eval() | 
					
						
							|  |  |  |     if test_fn: | 
					
						
							|  |  |  |         test_fn(epoch) | 
					
						
							| 
									
										
										
										
											2020-03-27 09:04:29 +08:00
										 |  |  |     if collector.get_env_num() > 1 and np.isscalar(n_episode): | 
					
						
							|  |  |  |         n = collector.get_env_num() | 
					
						
							|  |  |  |         n_ = np.zeros(n) + n_episode // n | 
					
						
							|  |  |  |         n_[:n_episode % n] += 1 | 
					
						
							|  |  |  |         n_episode = list(n_) | 
					
						
							| 
									
										
										
										
											2020-03-26 17:32:51 +08:00
										 |  |  |     return collector.collect(n_episode=n_episode) | 
					
						
							| 
									
										
										
										
											2020-03-20 19:52:29 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-05-12 11:31:47 +08:00
										 |  |  | def gather_info(start_time: float, | 
					
						
							|  |  |  |                 train_c: Collector, | 
					
						
							|  |  |  |                 test_c: Collector, | 
					
						
							|  |  |  |                 best_reward: float | 
					
						
							|  |  |  |                 ) -> Dict[str, Union[float, str]]: | 
					
						
							| 
									
										
										
										
											2020-04-05 18:34:45 +08:00
										 |  |  |     """A simple wrapper of gathering information from collectors.
 | 
					
						
							| 
									
										
										
										
											2020-04-03 21:28:12 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-05 18:34:45 +08:00
										 |  |  |     :return: A dictionary with the following keys: | 
					
						
							| 
									
										
										
										
											2020-04-03 21:28:12 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-05 18:34:45 +08:00
										 |  |  |         * ``train_step`` the total collected step of training collector; | 
					
						
							|  |  |  |         * ``train_episode`` the total collected episode of training collector; | 
					
						
							|  |  |  |         * ``train_time/collector`` the time for collecting frames in the \ | 
					
						
							| 
									
										
										
										
											2020-04-03 21:28:12 +08:00
										 |  |  |             training collector; | 
					
						
							| 
									
										
										
										
											2020-04-05 18:34:45 +08:00
										 |  |  |         * ``train_time/model`` the time for training models; | 
					
						
							|  |  |  |         * ``train_speed`` the speed of training (frames per second); | 
					
						
							|  |  |  |         * ``test_step`` the total collected step of test collector; | 
					
						
							|  |  |  |         * ``test_episode`` the total collected episode of test collector; | 
					
						
							|  |  |  |         * ``test_time`` the time for testing; | 
					
						
							|  |  |  |         * ``test_speed`` the speed of testing (frames per second); | 
					
						
							|  |  |  |         * ``best_reward`` the best reward over the test results; | 
					
						
							|  |  |  |         * ``duration`` the total elapsed time. | 
					
						
							| 
									
										
										
										
											2020-04-03 21:28:12 +08:00
										 |  |  |     """
 | 
					
						
							| 
									
										
										
										
											2020-03-20 19:52:29 +08:00
										 |  |  |     duration = time.time() - start_time | 
					
						
							|  |  |  |     model_time = duration - train_c.collect_time - test_c.collect_time | 
					
						
							|  |  |  |     train_speed = train_c.collect_step / (duration - test_c.collect_time) | 
					
						
							|  |  |  |     test_speed = test_c.collect_step / test_c.collect_time | 
					
						
							|  |  |  |     return { | 
					
						
							|  |  |  |         'train_step': train_c.collect_step, | 
					
						
							|  |  |  |         'train_episode': train_c.collect_episode, | 
					
						
							|  |  |  |         'train_time/collector': f'{train_c.collect_time:.2f}s', | 
					
						
							|  |  |  |         'train_time/model': f'{model_time:.2f}s', | 
					
						
							|  |  |  |         'train_speed': f'{train_speed:.2f} step/s', | 
					
						
							|  |  |  |         'test_step': test_c.collect_step, | 
					
						
							|  |  |  |         'test_episode': test_c.collect_episode, | 
					
						
							|  |  |  |         'test_time': f'{test_c.collect_time:.2f}s', | 
					
						
							|  |  |  |         'test_speed': f'{test_speed:.2f} step/s', | 
					
						
							|  |  |  |         'best_reward': best_reward, | 
					
						
							|  |  |  |         'duration': f'{duration:.2f}s', | 
					
						
							|  |  |  |     } |