Add write_flush in two loggers, fix argument passing in WandbLogger (#581)

This commit is contained in:
Jiayi Weng 2022-03-29 20:04:23 -04:00 committed by GitHub
parent 6ab9860183
commit f13e415eb0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 2 deletions

View File

@ -17,6 +17,8 @@ class TensorboardLogger(BaseLogger):
:param int update_interval: the log interval in log_update_data(). Default to 1000.
:param int save_interval: the save interval in save_data(). Default to 1 (save at
the end of each epoch).
:param bool write_flush: whether to flush tensorboard result after each
add_scalar operation. Default to True.
"""
def __init__(
@ -26,16 +28,19 @@ class TensorboardLogger(BaseLogger):
test_interval: int = 1,
update_interval: int = 1000,
save_interval: int = 1,
write_flush: bool = True,
) -> None:
super().__init__(train_interval, test_interval, update_interval)
self.save_interval = save_interval
self.write_flush = write_flush
self.last_save_step = -1
self.writer = writer
def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
for k, v in data.items():
self.writer.add_scalar(k, v, global_step=step)
self.writer.flush() # issue #482
if self.write_flush: # issue 580
self.writer.flush() # issue #482
def save_data(
self,

View File

@ -31,6 +31,10 @@ class WandbLogger(BaseLogger):
:param int test_interval: the log interval in log_test_data(). Default to 1.
:param int update_interval: the log interval in log_update_data().
Default to 1000.
:param int save_interval: the save interval in save_data(). Default to 1 (save at
the end of each epoch).
:param bool write_flush: whether to flush tensorboard result after each
add_scalar operation. Default to True.
:param str project: W&B project name. Default to "tianshou".
:param str name: W&B run name. Default to None. If None, random name is assigned.
:param str entity: W&B team/organization name. Default to None.
@ -44,6 +48,7 @@ class WandbLogger(BaseLogger):
test_interval: int = 1,
update_interval: int = 1000,
save_interval: int = 1000,
write_flush: bool = True,
project: Optional[str] = None,
name: Optional[str] = None,
entity: Optional[str] = None,
@ -53,6 +58,7 @@ class WandbLogger(BaseLogger):
super().__init__(train_interval, test_interval, update_interval)
self.last_save_step = -1
self.save_interval = save_interval
self.write_flush = write_flush
self.restored = False
if project is None:
project = os.getenv("WANDB_PROJECT", "tianshou")
@ -72,7 +78,10 @@ class WandbLogger(BaseLogger):
def load(self, writer: SummaryWriter) -> None:
self.writer = writer
self.tensorboard_logger = TensorboardLogger(writer)
self.tensorboard_logger = TensorboardLogger(
writer, self.train_interval, self.test_interval, self.update_interval,
self.save_interval, self.write_flush
)
def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
if self.tensorboard_logger is None: