Add write_flush in two loggers, fix argument passing in WandbLogger (#581)
This commit is contained in:
		
							parent
							
								
									6ab9860183
								
							
						
					
					
						commit
						f13e415eb0
					
				| @ -17,6 +17,8 @@ class TensorboardLogger(BaseLogger): | |||||||
|     :param int update_interval: the log interval in log_update_data(). Default to 1000. |     :param int update_interval: the log interval in log_update_data(). Default to 1000. | ||||||
|     :param int save_interval: the save interval in save_data(). Default to 1 (save at |     :param int save_interval: the save interval in save_data(). Default to 1 (save at | ||||||
|         the end of each epoch). |         the end of each epoch). | ||||||
|  |     :param bool write_flush: whether to flush tensorboard result after each | ||||||
|  |         add_scalar operation. Default to True. | ||||||
|     """ |     """ | ||||||
| 
 | 
 | ||||||
|     def __init__( |     def __init__( | ||||||
| @ -26,16 +28,19 @@ class TensorboardLogger(BaseLogger): | |||||||
|         test_interval: int = 1, |         test_interval: int = 1, | ||||||
|         update_interval: int = 1000, |         update_interval: int = 1000, | ||||||
|         save_interval: int = 1, |         save_interval: int = 1, | ||||||
|  |         write_flush: bool = True, | ||||||
|     ) -> None: |     ) -> None: | ||||||
|         super().__init__(train_interval, test_interval, update_interval) |         super().__init__(train_interval, test_interval, update_interval) | ||||||
|         self.save_interval = save_interval |         self.save_interval = save_interval | ||||||
|  |         self.write_flush = write_flush | ||||||
|         self.last_save_step = -1 |         self.last_save_step = -1 | ||||||
|         self.writer = writer |         self.writer = writer | ||||||
| 
 | 
 | ||||||
|     def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None: |     def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None: | ||||||
|         for k, v in data.items(): |         for k, v in data.items(): | ||||||
|             self.writer.add_scalar(k, v, global_step=step) |             self.writer.add_scalar(k, v, global_step=step) | ||||||
|         self.writer.flush()  # issue #482 |         if self.write_flush:  # issue 580 | ||||||
|  |             self.writer.flush()  # issue #482 | ||||||
| 
 | 
 | ||||||
|     def save_data( |     def save_data( | ||||||
|         self, |         self, | ||||||
|  | |||||||
| @ -31,6 +31,10 @@ class WandbLogger(BaseLogger): | |||||||
|     :param int test_interval: the log interval in log_test_data(). Default to 1. |     :param int test_interval: the log interval in log_test_data(). Default to 1. | ||||||
|     :param int update_interval: the log interval in log_update_data(). |     :param int update_interval: the log interval in log_update_data(). | ||||||
|         Default to 1000. |         Default to 1000. | ||||||
|  |     :param int save_interval: the save interval in save_data(). Default to 1 (save at | ||||||
|  |         the end of each epoch). | ||||||
|  |     :param bool write_flush: whether to flush tensorboard result after each | ||||||
|  |         add_scalar operation. Default to True. | ||||||
|     :param str project: W&B project name. Default to "tianshou". |     :param str project: W&B project name. Default to "tianshou". | ||||||
|     :param str name: W&B run name. Default to None. If None, random name is assigned. |     :param str name: W&B run name. Default to None. If None, random name is assigned. | ||||||
|     :param str entity: W&B team/organization name. Default to None. |     :param str entity: W&B team/organization name. Default to None. | ||||||
| @ -44,6 +48,7 @@ class WandbLogger(BaseLogger): | |||||||
|         test_interval: int = 1, |         test_interval: int = 1, | ||||||
|         update_interval: int = 1000, |         update_interval: int = 1000, | ||||||
|         save_interval: int = 1000, |         save_interval: int = 1000, | ||||||
|  |         write_flush: bool = True, | ||||||
|         project: Optional[str] = None, |         project: Optional[str] = None, | ||||||
|         name: Optional[str] = None, |         name: Optional[str] = None, | ||||||
|         entity: Optional[str] = None, |         entity: Optional[str] = None, | ||||||
| @ -53,6 +58,7 @@ class WandbLogger(BaseLogger): | |||||||
|         super().__init__(train_interval, test_interval, update_interval) |         super().__init__(train_interval, test_interval, update_interval) | ||||||
|         self.last_save_step = -1 |         self.last_save_step = -1 | ||||||
|         self.save_interval = save_interval |         self.save_interval = save_interval | ||||||
|  |         self.write_flush = write_flush | ||||||
|         self.restored = False |         self.restored = False | ||||||
|         if project is None: |         if project is None: | ||||||
|             project = os.getenv("WANDB_PROJECT", "tianshou") |             project = os.getenv("WANDB_PROJECT", "tianshou") | ||||||
| @ -72,7 +78,10 @@ class WandbLogger(BaseLogger): | |||||||
| 
 | 
 | ||||||
|     def load(self, writer: SummaryWriter) -> None: |     def load(self, writer: SummaryWriter) -> None: | ||||||
|         self.writer = writer |         self.writer = writer | ||||||
|         self.tensorboard_logger = TensorboardLogger(writer) |         self.tensorboard_logger = TensorboardLogger( | ||||||
|  |             writer, self.train_interval, self.test_interval, self.update_interval, | ||||||
|  |             self.save_interval, self.write_flush | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|     def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None: |     def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None: | ||||||
|         if self.tensorboard_logger is None: |         if self.tensorboard_logger is None: | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user