| 
									
										
										
										
											2021-02-24 14:48:42 +08:00
										 |  |  | import numpy as np | 
					
						
							|  |  |  | from numbers import Number | 
					
						
							|  |  |  | from typing import Any, Union | 
					
						
							|  |  |  | from abc import ABC, abstractmethod | 
					
						
							|  |  |  | from torch.utils.tensorboard import SummaryWriter | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class BaseLogger(ABC): | 
					
						
							|  |  |  |     """The base class for any logger which is compatible with trainer.""" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __init__(self, writer: Any) -> None: | 
					
						
							|  |  |  |         super().__init__() | 
					
						
							|  |  |  |         self.writer = writer | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @abstractmethod | 
					
						
							|  |  |  |     def write( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         key: str, | 
					
						
							|  |  |  |         x: Union[Number, np.number, np.ndarray], | 
					
						
							|  |  |  |         y: Union[Number, np.number, np.ndarray], | 
					
						
							|  |  |  |         **kwargs: Any, | 
					
						
							|  |  |  |     ) -> None: | 
					
						
							|  |  |  |         """Specify how the writer is used to log data.
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :param key: namespace which the input data tuple belongs to. | 
					
						
							|  |  |  |         :param x: stands for the ordinate of the input data tuple. | 
					
						
							|  |  |  |         :param y: stands for the abscissa of the input data tuple. | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         pass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def log_train_data(self, collect_result: dict, step: int) -> None: | 
					
						
							|  |  |  |         """Use writer to log statistics generated during training.
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :param collect_result: a dict containing information of data collected in | 
					
						
							|  |  |  |             training stage, i.e., returns of collector.collect(). | 
					
						
							|  |  |  |         :param int step: stands for the timestep the collect_result being logged. | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         pass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def log_update_data(self, update_result: dict, step: int) -> None: | 
					
						
							|  |  |  |         """Use writer to log statistics generated during updating.
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :param update_result: a dict containing information of data collected in | 
					
						
							|  |  |  |             updating stage, i.e., returns of policy.update(). | 
					
						
							|  |  |  |         :param int step: stands for the timestep the collect_result being logged. | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         pass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def log_test_data(self, collect_result: dict, step: int) -> None: | 
					
						
							|  |  |  |         """Use writer to log statistics generated during evaluating.
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :param collect_result: a dict containing information of data collected in | 
					
						
							|  |  |  |             evaluating stage, i.e., returns of collector.collect(). | 
					
						
							|  |  |  |         :param int step: stands for the timestep the collect_result being logged. | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         pass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class BasicLogger(BaseLogger): | 
					
						
							|  |  |  |     """A loggger that relies on tensorboard SummaryWriter by default to visualize \
 | 
					
						
							|  |  |  |     and log statistics. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     You can also rewrite write() func to use your own writer. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     :param SummaryWriter writer: the writer to log data. | 
					
						
							|  |  |  |     :param int train_interval: the log interval in log_train_data(). Default to 1. | 
					
						
							|  |  |  |     :param int test_interval: the log interval in log_test_data(). Default to 1. | 
					
						
							|  |  |  |     :param int update_interval: the log interval in log_update_data(). Default to 1000. | 
					
						
							| 
									
										
										
										
											2020-10-31 16:38:54 +08:00
										 |  |  |     """
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-02-24 14:48:42 +08:00
										 |  |  |     def __init__( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         writer: SummaryWriter, | 
					
						
							|  |  |  |         train_interval: int = 1, | 
					
						
							|  |  |  |         test_interval: int = 1, | 
					
						
							|  |  |  |         update_interval: int = 1000, | 
					
						
							|  |  |  |     ) -> None: | 
					
						
							|  |  |  |         super().__init__(writer) | 
					
						
							|  |  |  |         self.train_interval = train_interval | 
					
						
							|  |  |  |         self.test_interval = test_interval | 
					
						
							|  |  |  |         self.update_interval = update_interval | 
					
						
							|  |  |  |         self.last_log_train_step = -1 | 
					
						
							|  |  |  |         self.last_log_test_step = -1 | 
					
						
							|  |  |  |         self.last_log_update_step = -1 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def write( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         key: str, | 
					
						
							|  |  |  |         x: Union[Number, np.number, np.ndarray], | 
					
						
							|  |  |  |         y: Union[Number, np.number, np.ndarray], | 
					
						
							|  |  |  |         **kwargs: Any, | 
					
						
							|  |  |  |     ) -> None: | 
					
						
							|  |  |  |         self.writer.add_scalar(key, y, global_step=x) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def log_train_data(self, collect_result: dict, step: int) -> None: | 
					
						
							|  |  |  |         """Use writer to log statistics generated during training.
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :param collect_result: a dict containing information of data collected in | 
					
						
							|  |  |  |             training stage, i.e., returns of collector.collect(). | 
					
						
							|  |  |  |         :param int step: stands for the timestep the collect_result being logged. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         .. note:: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             ``collect_result`` will be modified in-place with "rew" and "len" keys. | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         if collect_result["n/ep"] > 0: | 
					
						
							|  |  |  |             collect_result["rew"] = collect_result["rews"].mean() | 
					
						
							|  |  |  |             collect_result["len"] = collect_result["lens"].mean() | 
					
						
							|  |  |  |             if step - self.last_log_train_step >= self.train_interval: | 
					
						
							|  |  |  |                 self.write("train/n/ep", step, collect_result["n/ep"]) | 
					
						
							|  |  |  |                 self.write("train/rew", step, collect_result["rew"]) | 
					
						
							|  |  |  |                 self.write("train/len", step, collect_result["len"]) | 
					
						
							|  |  |  |                 self.last_log_train_step = step | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def log_test_data(self, collect_result: dict, step: int) -> None: | 
					
						
							|  |  |  |         """Use writer to log statistics generated during evaluating.
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :param collect_result: a dict containing information of data collected in | 
					
						
							|  |  |  |             evaluating stage, i.e., returns of collector.collect(). | 
					
						
							|  |  |  |         :param int step: stands for the timestep the collect_result being logged. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         .. note:: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             ``collect_result`` will be modified in-place with "rew", "rew_std", "len", | 
					
						
							|  |  |  |             and "len_std" keys. | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         assert collect_result["n/ep"] > 0 | 
					
						
							|  |  |  |         rews, lens = collect_result["rews"], collect_result["lens"] | 
					
						
							|  |  |  |         rew, rew_std, len_, len_std = rews.mean(), rews.std(), lens.mean(), lens.std() | 
					
						
							|  |  |  |         collect_result.update(rew=rew, rew_std=rew_std, len=len_, len_std=len_std) | 
					
						
							|  |  |  |         if step - self.last_log_test_step >= self.test_interval: | 
					
						
							|  |  |  |             self.write("test/rew", step, rew) | 
					
						
							|  |  |  |             self.write("test/len", step, len_) | 
					
						
							|  |  |  |             self.write("test/rew_std", step, rew_std) | 
					
						
							|  |  |  |             self.write("test/len_std", step, len_std) | 
					
						
							|  |  |  |             self.last_log_test_step = step | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def log_update_data(self, update_result: dict, step: int) -> None: | 
					
						
							|  |  |  |         if step - self.last_log_update_step >= self.update_interval: | 
					
						
							|  |  |  |             for k, v in update_result.items(): | 
					
						
							| 
									
										
										
										
											2021-02-26 13:23:18 +08:00
										 |  |  |                 self.write(k, step, v) | 
					
						
							| 
									
										
										
										
											2021-02-24 14:48:42 +08:00
										 |  |  |             self.last_log_update_step = step | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class LazyLogger(BasicLogger): | 
					
						
							|  |  |  |     """A loggger that does nothing. Used as the placeholder in trainer.""" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __init__(self) -> None: | 
					
						
							|  |  |  |         super().__init__(None)  # type: ignore | 
					
						
							| 
									
										
										
										
											2020-10-31 16:38:54 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-02-24 14:48:42 +08:00
										 |  |  |     def write( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         key: str, | 
					
						
							|  |  |  |         x: Union[Number, np.number, np.ndarray], | 
					
						
							|  |  |  |         y: Union[Number, np.number, np.ndarray], | 
					
						
							| 
									
										
										
										
											2020-10-31 16:38:54 +08:00
										 |  |  |         **kwargs: Any, | 
					
						
							| 
									
										
										
										
											2021-02-24 14:48:42 +08:00
										 |  |  |     ) -> None: | 
					
						
							|  |  |  |         """The LazyLogger writes nothing.""" | 
					
						
							|  |  |  |         pass |