Add write_flush in two loggers, fix argument passing in WandbLogger (#581)
This commit is contained in:
parent
6ab9860183
commit
f13e415eb0
@ -17,6 +17,8 @@ class TensorboardLogger(BaseLogger):
|
|||||||
:param int update_interval: the log interval in log_update_data(). Default to 1000.
|
:param int update_interval: the log interval in log_update_data(). Default to 1000.
|
||||||
:param int save_interval: the save interval in save_data(). Default to 1 (save at
|
:param int save_interval: the save interval in save_data(). Default to 1 (save at
|
||||||
the end of each epoch).
|
the end of each epoch).
|
||||||
|
:param bool write_flush: whether to flush tensorboard result after each
|
||||||
|
add_scalar operation. Default to True.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -26,16 +28,19 @@ class TensorboardLogger(BaseLogger):
|
|||||||
test_interval: int = 1,
|
test_interval: int = 1,
|
||||||
update_interval: int = 1000,
|
update_interval: int = 1000,
|
||||||
save_interval: int = 1,
|
save_interval: int = 1,
|
||||||
|
write_flush: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(train_interval, test_interval, update_interval)
|
super().__init__(train_interval, test_interval, update_interval)
|
||||||
self.save_interval = save_interval
|
self.save_interval = save_interval
|
||||||
|
self.write_flush = write_flush
|
||||||
self.last_save_step = -1
|
self.last_save_step = -1
|
||||||
self.writer = writer
|
self.writer = writer
|
||||||
|
|
||||||
def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
|
def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
|
||||||
for k, v in data.items():
|
for k, v in data.items():
|
||||||
self.writer.add_scalar(k, v, global_step=step)
|
self.writer.add_scalar(k, v, global_step=step)
|
||||||
self.writer.flush() # issue #482
|
if self.write_flush: # issue 580
|
||||||
|
self.writer.flush() # issue #482
|
||||||
|
|
||||||
def save_data(
|
def save_data(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -31,6 +31,10 @@ class WandbLogger(BaseLogger):
|
|||||||
:param int test_interval: the log interval in log_test_data(). Default to 1.
|
:param int test_interval: the log interval in log_test_data(). Default to 1.
|
||||||
:param int update_interval: the log interval in log_update_data().
|
:param int update_interval: the log interval in log_update_data().
|
||||||
Default to 1000.
|
Default to 1000.
|
||||||
|
:param int save_interval: the save interval in save_data(). Default to 1 (save at
|
||||||
|
the end of each epoch).
|
||||||
|
:param bool write_flush: whether to flush tensorboard result after each
|
||||||
|
add_scalar operation. Default to True.
|
||||||
:param str project: W&B project name. Default to "tianshou".
|
:param str project: W&B project name. Default to "tianshou".
|
||||||
:param str name: W&B run name. Default to None. If None, random name is assigned.
|
:param str name: W&B run name. Default to None. If None, random name is assigned.
|
||||||
:param str entity: W&B team/organization name. Default to None.
|
:param str entity: W&B team/organization name. Default to None.
|
||||||
@ -44,6 +48,7 @@ class WandbLogger(BaseLogger):
|
|||||||
test_interval: int = 1,
|
test_interval: int = 1,
|
||||||
update_interval: int = 1000,
|
update_interval: int = 1000,
|
||||||
save_interval: int = 1000,
|
save_interval: int = 1000,
|
||||||
|
write_flush: bool = True,
|
||||||
project: Optional[str] = None,
|
project: Optional[str] = None,
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
entity: Optional[str] = None,
|
entity: Optional[str] = None,
|
||||||
@ -53,6 +58,7 @@ class WandbLogger(BaseLogger):
|
|||||||
super().__init__(train_interval, test_interval, update_interval)
|
super().__init__(train_interval, test_interval, update_interval)
|
||||||
self.last_save_step = -1
|
self.last_save_step = -1
|
||||||
self.save_interval = save_interval
|
self.save_interval = save_interval
|
||||||
|
self.write_flush = write_flush
|
||||||
self.restored = False
|
self.restored = False
|
||||||
if project is None:
|
if project is None:
|
||||||
project = os.getenv("WANDB_PROJECT", "tianshou")
|
project = os.getenv("WANDB_PROJECT", "tianshou")
|
||||||
@ -72,7 +78,10 @@ class WandbLogger(BaseLogger):
|
|||||||
|
|
||||||
def load(self, writer: SummaryWriter) -> None:
|
def load(self, writer: SummaryWriter) -> None:
|
||||||
self.writer = writer
|
self.writer = writer
|
||||||
self.tensorboard_logger = TensorboardLogger(writer)
|
self.tensorboard_logger = TensorboardLogger(
|
||||||
|
writer, self.train_interval, self.test_interval, self.update_interval,
|
||||||
|
self.save_interval, self.write_flush
|
||||||
|
)
|
||||||
|
|
||||||
def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
|
def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
|
||||||
if self.tensorboard_logger is None:
|
if self.tensorboard_logger is None:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user