| 
									
										
										
										
											2021-08-30 10:35:02 -04:00
										 |  |  | from abc import ABC, abstractmethod | 
					
						
							| 
									
										
										
										
											2023-09-05 23:34:23 +02:00
										 |  |  | from collections.abc import Callable | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  | from numbers import Number | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import numpy as np | 
					
						
							| 
									
										
										
										
											2021-08-30 10:35:02 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-05 23:34:23 +02:00
										 |  |  | LOG_DATA_TYPE = dict[str, int | Number | np.number | np.ndarray] | 
					
						
							| 
									
										
										
										
											2021-08-30 10:35:02 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class BaseLogger(ABC): | 
					
						
							|  |  |  |     """The base class for any logger which is compatible with trainer.
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     Try to overwrite write() method to use your own writer. | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-10-08 17:57:03 +02:00
										 |  |  |     :param train_interval: the log interval in log_train_data(). Default to 1000. | 
					
						
							|  |  |  |     :param test_interval: the log interval in log_test_data(). Default to 1. | 
					
						
							|  |  |  |     :param update_interval: the log interval in log_update_data(). Default to 1000. | 
					
						
							| 
									
										
										
										
											2021-08-30 10:35:02 -04:00
										 |  |  |     """
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __init__( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         train_interval: int = 1000, | 
					
						
							|  |  |  |         test_interval: int = 1, | 
					
						
							|  |  |  |         update_interval: int = 1000, | 
					
						
							|  |  |  |     ) -> None: | 
					
						
							|  |  |  |         super().__init__() | 
					
						
							|  |  |  |         self.train_interval = train_interval | 
					
						
							|  |  |  |         self.test_interval = test_interval | 
					
						
							|  |  |  |         self.update_interval = update_interval | 
					
						
							|  |  |  |         self.last_log_train_step = -1 | 
					
						
							|  |  |  |         self.last_log_test_step = -1 | 
					
						
							|  |  |  |         self.last_log_update_step = -1 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @abstractmethod | 
					
						
							|  |  |  |     def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None: | 
					
						
							|  |  |  |         """Specify how the writer is used to log data.
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :param str step_type: namespace which the data dict belongs to. | 
					
						
							| 
									
										
										
										
											2023-10-08 17:57:03 +02:00
										 |  |  |         :param step: stands for the ordinate of the data dict. | 
					
						
							|  |  |  |         :param data: the data to write with format ``{key: value}``. | 
					
						
							| 
									
										
										
										
											2021-08-30 10:35:02 -04:00
										 |  |  |         """
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def log_train_data(self, collect_result: dict, step: int) -> None: | 
					
						
							|  |  |  |         """Use writer to log statistics generated during training.
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :param collect_result: a dict containing information of data collected in | 
					
						
							|  |  |  |             training stage, i.e., returns of collector.collect(). | 
					
						
							| 
									
										
										
										
											2023-10-08 17:57:03 +02:00
										 |  |  |         :param step: stands for the timestep the collect_result being logged. | 
					
						
							| 
									
										
										
										
											2021-08-30 10:35:02 -04:00
										 |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |         if collect_result["n/ep"] > 0 and step - self.last_log_train_step >= self.train_interval: | 
					
						
							|  |  |  |             log_data = { | 
					
						
							|  |  |  |                 "train/episode": collect_result["n/ep"], | 
					
						
							|  |  |  |                 "train/reward": collect_result["rew"], | 
					
						
							|  |  |  |                 "train/length": collect_result["len"], | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             self.write("train/env_step", step, log_data) | 
					
						
							|  |  |  |             self.last_log_train_step = step | 
					
						
							| 
									
										
										
										
											2021-08-30 10:35:02 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def log_test_data(self, collect_result: dict, step: int) -> None: | 
					
						
							|  |  |  |         """Use writer to log statistics generated during evaluating.
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :param collect_result: a dict containing information of data collected in | 
					
						
							|  |  |  |             evaluating stage, i.e., returns of collector.collect(). | 
					
						
							| 
									
										
										
										
											2023-10-08 17:57:03 +02:00
										 |  |  |         :param step: stands for the timestep the collect_result being logged. | 
					
						
							| 
									
										
										
										
											2021-08-30 10:35:02 -04:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         assert collect_result["n/ep"] > 0 | 
					
						
							|  |  |  |         if step - self.last_log_test_step >= self.test_interval: | 
					
						
							|  |  |  |             log_data = { | 
					
						
							|  |  |  |                 "test/env_step": step, | 
					
						
							| 
									
										
										
										
											2021-10-13 09:25:24 -04:00
										 |  |  |                 "test/reward": collect_result["rew"], | 
					
						
							|  |  |  |                 "test/length": collect_result["len"], | 
					
						
							|  |  |  |                 "test/reward_std": collect_result["rew_std"], | 
					
						
							|  |  |  |                 "test/length_std": collect_result["len_std"], | 
					
						
							| 
									
										
										
										
											2021-08-30 10:35:02 -04:00
										 |  |  |             } | 
					
						
							|  |  |  |             self.write("test/env_step", step, log_data) | 
					
						
							|  |  |  |             self.last_log_test_step = step | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def log_update_data(self, update_result: dict, step: int) -> None: | 
					
						
							|  |  |  |         """Use writer to log statistics generated during updating.
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :param update_result: a dict containing information of data collected in | 
					
						
							|  |  |  |             updating stage, i.e., returns of policy.update(). | 
					
						
							| 
									
										
										
										
											2023-10-08 17:57:03 +02:00
										 |  |  |         :param step: stands for the timestep the collect_result being logged. | 
					
						
							| 
									
										
										
										
											2021-08-30 10:35:02 -04:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         if step - self.last_log_update_step >= self.update_interval: | 
					
						
							|  |  |  |             log_data = {f"update/{k}": v for k, v in update_result.items()} | 
					
						
							|  |  |  |             self.write("update/gradient_step", step, log_data) | 
					
						
							|  |  |  |             self.last_log_update_step = step | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-31 08:54:54 +09:00
										 |  |  |     @abstractmethod | 
					
						
							| 
									
										
										
										
											2021-08-30 10:35:02 -04:00
										 |  |  |     def save_data( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         epoch: int, | 
					
						
							|  |  |  |         env_step: int, | 
					
						
							|  |  |  |         gradient_step: int, | 
					
						
							| 
									
										
										
										
											2023-09-05 23:34:23 +02:00
										 |  |  |         save_checkpoint_fn: Callable[[int, int, int], str] | None = None, | 
					
						
							| 
									
										
										
										
											2021-08-30 10:35:02 -04:00
										 |  |  |     ) -> None: | 
					
						
							|  |  |  |         """Use writer to log metadata when calling ``save_checkpoint_fn`` in trainer.
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-10-08 17:57:03 +02:00
										 |  |  |         :param epoch: the epoch in trainer. | 
					
						
							|  |  |  |         :param env_step: the env_step in trainer. | 
					
						
							|  |  |  |         :param gradient_step: the gradient_step in trainer. | 
					
						
							| 
									
										
										
										
											2021-08-30 10:35:02 -04:00
										 |  |  |         :param function save_checkpoint_fn: a hook defined by user, see trainer | 
					
						
							|  |  |  |             documentation for detail. | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-31 08:54:54 +09:00
										 |  |  |     @abstractmethod | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     def restore_data(self) -> tuple[int, int, int]: | 
					
						
							| 
									
										
										
										
											2021-08-30 10:35:02 -04:00
										 |  |  |         """Return the metadata from existing log.
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         If it finds nothing or an error occurs during the recover process, it will | 
					
						
							|  |  |  |         return the default parameters. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :return: epoch, env_step, gradient_step. | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class LazyLogger(BaseLogger): | 
					
						
							|  |  |  |     """A logger that does nothing. Used as the placeholder in trainer.""" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __init__(self) -> None: | 
					
						
							|  |  |  |         super().__init__() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None: | 
					
						
							|  |  |  |         """The LazyLogger writes nothing.""" | 
					
						
							| 
									
										
										
										
											2022-10-31 08:54:54 +09:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def save_data( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         epoch: int, | 
					
						
							|  |  |  |         env_step: int, | 
					
						
							|  |  |  |         gradient_step: int, | 
					
						
							| 
									
										
										
										
											2023-09-05 23:34:23 +02:00
										 |  |  |         save_checkpoint_fn: Callable[[int, int, int], str] | None = None, | 
					
						
							| 
									
										
										
										
											2022-10-31 08:54:54 +09:00
										 |  |  |     ) -> None: | 
					
						
							|  |  |  |         pass | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-25 23:40:56 +02:00
										 |  |  |     def restore_data(self) -> tuple[int, int, int]: | 
					
						
							| 
									
										
										
										
											2022-11-11 20:25:35 +00:00
										 |  |  |         return 0, 0, 0 |