Add write_flush in two loggers, fix argument passing in WandbLogger (#581)
This commit is contained in:
parent
6ab9860183
commit
f13e415eb0
@ -17,6 +17,8 @@ class TensorboardLogger(BaseLogger):
|
||||
:param int update_interval: the log interval in log_update_data(). Default to 1000.
|
||||
:param int save_interval: the save interval in save_data(). Default to 1 (save at
|
||||
the end of each epoch).
|
||||
:param bool write_flush: whether to flush tensorboard result after each
|
||||
add_scalar operation. Default to True.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -26,16 +28,19 @@ class TensorboardLogger(BaseLogger):
|
||||
test_interval: int = 1,
|
||||
update_interval: int = 1000,
|
||||
save_interval: int = 1,
|
||||
write_flush: bool = True,
|
||||
) -> None:
|
||||
super().__init__(train_interval, test_interval, update_interval)
|
||||
self.save_interval = save_interval
|
||||
self.write_flush = write_flush
|
||||
self.last_save_step = -1
|
||||
self.writer = writer
|
||||
|
||||
def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
|
||||
for k, v in data.items():
|
||||
self.writer.add_scalar(k, v, global_step=step)
|
||||
self.writer.flush() # issue #482
|
||||
if self.write_flush: # issue 580
|
||||
self.writer.flush() # issue #482
|
||||
|
||||
def save_data(
|
||||
self,
|
||||
|
@ -31,6 +31,10 @@ class WandbLogger(BaseLogger):
|
||||
:param int test_interval: the log interval in log_test_data(). Default to 1.
|
||||
:param int update_interval: the log interval in log_update_data().
|
||||
Default to 1000.
|
||||
:param int save_interval: the save interval in save_data(). Default to 1 (save at
|
||||
the end of each epoch).
|
||||
:param bool write_flush: whether to flush tensorboard result after each
|
||||
add_scalar operation. Default to True.
|
||||
:param str project: W&B project name. Default to "tianshou".
|
||||
:param str name: W&B run name. Default to None. If None, random name is assigned.
|
||||
:param str entity: W&B team/organization name. Default to None.
|
||||
@ -44,6 +48,7 @@ class WandbLogger(BaseLogger):
|
||||
test_interval: int = 1,
|
||||
update_interval: int = 1000,
|
||||
save_interval: int = 1000,
|
||||
write_flush: bool = True,
|
||||
project: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
entity: Optional[str] = None,
|
||||
@ -53,6 +58,7 @@ class WandbLogger(BaseLogger):
|
||||
super().__init__(train_interval, test_interval, update_interval)
|
||||
self.last_save_step = -1
|
||||
self.save_interval = save_interval
|
||||
self.write_flush = write_flush
|
||||
self.restored = False
|
||||
if project is None:
|
||||
project = os.getenv("WANDB_PROJECT", "tianshou")
|
||||
@ -72,7 +78,10 @@ class WandbLogger(BaseLogger):
|
||||
|
||||
def load(self, writer: SummaryWriter) -> None:
|
||||
self.writer = writer
|
||||
self.tensorboard_logger = TensorboardLogger(writer)
|
||||
self.tensorboard_logger = TensorboardLogger(
|
||||
writer, self.train_interval, self.test_interval, self.update_interval,
|
||||
self.save_interval, self.write_flush
|
||||
)
|
||||
|
||||
def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
|
||||
if self.tensorboard_logger is None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user