small fix

This commit is contained in:
Maximilian Huettenrauch 2024-03-12 15:17:33 +01:00
parent d9a612a997
commit a7898b15b8
4 changed files with 22 additions and 6 deletions

View File

@ -58,9 +58,15 @@ class BaseLogger(ABC):
:param data: the data to write with format ``{key: value}``. :param data: the data to write with format ``{key: value}``.
""" """
@staticmethod @abstractmethod
def prepare_dict_for_logging(log_data: dict) -> dict[str, VALID_LOG_VALS_TYPE]: def prepare_dict_for_logging(self, log_data: dict) -> dict[str, VALID_LOG_VALS_TYPE]:
return log_data """Prepare the dict for logging by filtering out invalid data types.
If necessary, reformulate the dict to be compatible with the writer.
:param log_data: the dict to be prepared for logging.
:return: the prepared dict.
"""
def log_train_data(self, log_data: dict, step: int) -> None: def log_train_data(self, log_data: dict, step: int) -> None:
"""Use writer to log statistics generated during training. """Use writer to log statistics generated during training.
@ -152,6 +158,12 @@ class LazyLogger(BaseLogger):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
def prepare_dict_for_logging(
self,
data: dict[str, VALID_LOG_VALS_TYPE],
) -> dict[str, VALID_LOG_VALS_TYPE]:
return data
def write(self, step_type: str, step: int, data: dict[str, VALID_LOG_VALS_TYPE]) -> None: def write(self, step_type: str, step: int, data: dict[str, VALID_LOG_VALS_TYPE]) -> None:
"""The LazyLogger writes nothing.""" """The LazyLogger writes nothing."""

View File

@ -32,8 +32,7 @@ class PandasLogger(BaseLogger):
self.data: dict[str, list] = defaultdict(list) self.data: dict[str, list] = defaultdict(list)
self.last_save_step = -1 self.last_save_step = -1
@staticmethod def prepare_dict_for_logging(self, data: dict[str, Any]) -> dict[str, VALID_LOG_VALS_TYPE]:
def prepare_dict_for_logging(data: dict[str, Any]) -> dict[str, VALID_LOG_VALS_TYPE]:
"""Removes invalid data types from the log data.""" """Removes invalid data types from the log data."""
filtered_dict = data.copy() filtered_dict = data.copy()

View File

@ -40,8 +40,8 @@ class TensorboardLogger(BaseLogger):
self.last_save_step = -1 self.last_save_step = -1
self.writer = writer self.writer = writer
@staticmethod
def prepare_dict_for_logging( def prepare_dict_for_logging(
self,
input_dict: dict[str, Any], input_dict: dict[str, Any],
parent_key: str = "", parent_key: str = "",
delimiter: str = "/", delimiter: str = "/",

View File

@ -82,6 +82,10 @@ class WandbLogger(BaseLogger):
self.wandb_run._label(repo="tianshou") # type: ignore self.wandb_run._label(repo="tianshou") # type: ignore
self.tensorboard_logger: TensorboardLogger | None = None self.tensorboard_logger: TensorboardLogger | None = None
def prepare_dict_for_logging(self, log_data: dict) -> dict[str, VALID_LOG_VALS_TYPE]:
assert self.tensorboard_logger is not None
return self.tensorboard_logger.prepare_dict_for_logging(log_data)
def load(self, writer: SummaryWriter) -> None: def load(self, writer: SummaryWriter) -> None:
self.writer = writer self.writer = writer
self.tensorboard_logger = TensorboardLogger( self.tensorboard_logger = TensorboardLogger(
@ -158,5 +162,6 @@ class WandbLogger(BaseLogger):
return epoch, env_step, gradient_step return epoch, env_step, gradient_step
def restore_logged_data(self, log_path: str) -> dict: def restore_logged_data(self, log_path: str) -> dict:
# TODO: add support for restoring logged data from W&B
assert self.tensorboard_logger is not None assert self.tensorboard_logger is not None
return self.tensorboard_logger.restore_logged_data(log_path) return self.tensorboard_logger.restore_logged_data(log_path)