Add write_flush in two loggers, fix argument passing in WandbLogger (#581)

This commit is contained in:
Jiayi Weng 2022-03-29 20:04:23 -04:00 committed by GitHub
parent 6ab9860183
commit f13e415eb0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 2 deletions

View File

@@ -17,6 +17,8 @@ class TensorboardLogger(BaseLogger):
     :param int update_interval: the log interval in log_update_data(). Default to 1000.
     :param int save_interval: the save interval in save_data(). Default to 1 (save at
         the end of each epoch).
+    :param bool write_flush: whether to flush tensorboard result after each
+        add_scalar operation. Default to True.
     """

     def __init__(
@@ -26,15 +28,18 @@ class TensorboardLogger(BaseLogger):
         test_interval: int = 1,
         update_interval: int = 1000,
         save_interval: int = 1,
+        write_flush: bool = True,
     ) -> None:
         super().__init__(train_interval, test_interval, update_interval)
         self.save_interval = save_interval
+        self.write_flush = write_flush
         self.last_save_step = -1
         self.writer = writer

     def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
         for k, v in data.items():
             self.writer.add_scalar(k, v, global_step=step)
-        self.writer.flush()  # issue #482
+        if self.write_flush:  # issue 580
+            self.writer.flush()  # issue #482

     def save_data(

View File

@@ -31,6 +31,10 @@ class WandbLogger(BaseLogger):
     :param int test_interval: the log interval in log_test_data(). Default to 1.
     :param int update_interval: the log interval in log_update_data().
         Default to 1000.
+    :param int save_interval: the save interval in save_data(). Default to 1 (save at
+        the end of each epoch).
+    :param bool write_flush: whether to flush tensorboard result after each
+        add_scalar operation. Default to True.
     :param str project: W&B project name. Default to "tianshou".
     :param str name: W&B run name. Default to None. If None, random name is assigned.
     :param str entity: W&B team/organization name. Default to None.
@@ -44,6 +48,7 @@ class WandbLogger(BaseLogger):
         test_interval: int = 1,
         update_interval: int = 1000,
         save_interval: int = 1000,
+        write_flush: bool = True,
         project: Optional[str] = None,
         name: Optional[str] = None,
         entity: Optional[str] = None,
@@ -53,6 +58,7 @@ class WandbLogger(BaseLogger):
         super().__init__(train_interval, test_interval, update_interval)
         self.last_save_step = -1
         self.save_interval = save_interval
+        self.write_flush = write_flush
         self.restored = False
         if project is None:
             project = os.getenv("WANDB_PROJECT", "tianshou")
@@ -72,7 +78,10 @@ class WandbLogger(BaseLogger):

     def load(self, writer: SummaryWriter) -> None:
         self.writer = writer
-        self.tensorboard_logger = TensorboardLogger(writer)
+        self.tensorboard_logger = TensorboardLogger(
+            writer, self.train_interval, self.test_interval, self.update_interval,
+            self.save_interval, self.write_flush
+        )

     def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
         if self.tensorboard_logger is None: