diff --git a/tianshou/utils/logger/tensorboard.py b/tianshou/utils/logger/tensorboard.py index bc43bae..843ff01 100644 --- a/tianshou/utils/logger/tensorboard.py +++ b/tianshou/utils/logger/tensorboard.py @@ -17,6 +17,8 @@ class TensorboardLogger(BaseLogger): :param int update_interval: the log interval in log_update_data(). Default to 1000. :param int save_interval: the save interval in save_data(). Default to 1 (save at the end of each epoch). + :param bool write_flush: whether to flush tensorboard result after each + add_scalar operation. Default to True. """ def __init__( @@ -26,16 +28,19 @@ class TensorboardLogger(BaseLogger): test_interval: int = 1, update_interval: int = 1000, save_interval: int = 1, + write_flush: bool = True, ) -> None: super().__init__(train_interval, test_interval, update_interval) self.save_interval = save_interval + self.write_flush = write_flush self.last_save_step = -1 self.writer = writer def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None: for k, v in data.items(): self.writer.add_scalar(k, v, global_step=step) - self.writer.flush() # issue #482 + if self.write_flush: # issue 580 + self.writer.flush() # issue #482 def save_data( self, diff --git a/tianshou/utils/logger/wandb.py b/tianshou/utils/logger/wandb.py index 32a89d2..e63a7bc 100644 --- a/tianshou/utils/logger/wandb.py +++ b/tianshou/utils/logger/wandb.py @@ -31,6 +31,10 @@ class WandbLogger(BaseLogger): :param int test_interval: the log interval in log_test_data(). Default to 1. :param int update_interval: the log interval in log_update_data(). Default to 1000. + :param int save_interval: the save interval in save_data(). Default to 1 (save at + the end of each epoch). + :param bool write_flush: whether to flush tensorboard result after each + add_scalar operation. Default to True. :param str project: W&B project name. Default to "tianshou". :param str name: W&B run name. Default to None. If None, random name is assigned. :param str entity: W&B team/organization name. Default to None. @@ -44,6 +48,7 @@ class WandbLogger(BaseLogger): test_interval: int = 1, update_interval: int = 1000, save_interval: int = 1000, + write_flush: bool = True, project: Optional[str] = None, name: Optional[str] = None, entity: Optional[str] = None, @@ -53,6 +58,7 @@ class WandbLogger(BaseLogger): super().__init__(train_interval, test_interval, update_interval) self.last_save_step = -1 self.save_interval = save_interval + self.write_flush = write_flush self.restored = False if project is None: project = os.getenv("WANDB_PROJECT", "tianshou") @@ -72,7 +78,10 @@ class WandbLogger(BaseLogger): def load(self, writer: SummaryWriter) -> None: self.writer = writer - self.tensorboard_logger = TensorboardLogger(writer) + self.tensorboard_logger = TensorboardLogger( + writer, self.train_interval, self.test_interval, self.update_interval, + self.save_interval, self.write_flush + ) def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None: if self.tensorboard_logger is None: