From a7898b15b8ac1520dce213f2dac6e50a7bd97967 Mon Sep 17 00:00:00 2001 From: Maximilian Huettenrauch Date: Tue, 12 Mar 2024 15:17:33 +0100 Subject: [PATCH] small fix --- tianshou/utils/logger/base.py | 18 +++++++++++++++--- tianshou/utils/logger/pandas_logger.py | 3 +-- tianshou/utils/logger/tensorboard.py | 2 +- tianshou/utils/logger/wandb.py | 5 +++++ 4 files changed, 22 insertions(+), 6 deletions(-) diff --git a/tianshou/utils/logger/base.py b/tianshou/utils/logger/base.py index a6a92a0..36d321f 100644 --- a/tianshou/utils/logger/base.py +++ b/tianshou/utils/logger/base.py @@ -58,9 +58,15 @@ class BaseLogger(ABC): :param data: the data to write with format ``{key: value}``. """ - @staticmethod - def prepare_dict_for_logging(log_data: dict) -> dict[str, VALID_LOG_VALS_TYPE]: - return log_data + @abstractmethod + def prepare_dict_for_logging(self, log_data: dict) -> dict[str, VALID_LOG_VALS_TYPE]: + """Prepare the dict for logging by filtering out invalid data types. + + If necessary, reformulate the dict to be compatible with the writer. + + :param log_data: the dict to be prepared for logging. + :return: the prepared dict. + """ def log_train_data(self, log_data: dict, step: int) -> None: """Use writer to log statistics generated during training. @@ -152,6 +158,12 @@ class LazyLogger(BaseLogger): def __init__(self) -> None: super().__init__() + def prepare_dict_for_logging( + self, + data: dict[str, VALID_LOG_VALS_TYPE], + ) -> dict[str, VALID_LOG_VALS_TYPE]: + return data + def write(self, step_type: str, step: int, data: dict[str, VALID_LOG_VALS_TYPE]) -> None: """The LazyLogger writes nothing.""" diff --git a/tianshou/utils/logger/pandas_logger.py b/tianshou/utils/logger/pandas_logger.py index f76b5da..3bdbe08 100644 --- a/tianshou/utils/logger/pandas_logger.py +++ b/tianshou/utils/logger/pandas_logger.py @@ -32,8 +32,7 @@ class PandasLogger(BaseLogger): self.data: dict[str, list] = defaultdict(list) self.last_save_step = -1 - @staticmethod - def prepare_dict_for_logging(data: dict[str, Any]) -> dict[str, VALID_LOG_VALS_TYPE]: + def prepare_dict_for_logging(self, data: dict[str, Any]) -> dict[str, VALID_LOG_VALS_TYPE]: """Removes invalid data types from the log data.""" filtered_dict = data.copy() diff --git a/tianshou/utils/logger/tensorboard.py b/tianshou/utils/logger/tensorboard.py index 8b4f33c..dcf0563 100644 --- a/tianshou/utils/logger/tensorboard.py +++ b/tianshou/utils/logger/tensorboard.py @@ -40,8 +40,8 @@ class TensorboardLogger(BaseLogger): self.last_save_step = -1 self.writer = writer - @staticmethod def prepare_dict_for_logging( + self, input_dict: dict[str, Any], parent_key: str = "", delimiter: str = "/", diff --git a/tianshou/utils/logger/wandb.py b/tianshou/utils/logger/wandb.py index 1ad012a..213124f 100644 --- a/tianshou/utils/logger/wandb.py +++ b/tianshou/utils/logger/wandb.py @@ -82,6 +82,10 @@ class WandbLogger(BaseLogger): self.wandb_run._label(repo="tianshou") # type: ignore self.tensorboard_logger: TensorboardLogger | None = None + def prepare_dict_for_logging(self, log_data: dict) -> dict[str, VALID_LOG_VALS_TYPE]: + assert self.tensorboard_logger is not None + return self.tensorboard_logger.prepare_dict_for_logging(log_data) + def load(self, writer: SummaryWriter) -> None: self.writer = writer self.tensorboard_logger = TensorboardLogger( @@ -158,5 +162,6 @@ class WandbLogger(BaseLogger): return epoch, env_step, gradient_step def restore_logged_data(self, log_path: str) -> dict: + # TODO: add support for restoring logged data from W&B assert self.tensorboard_logger is not None return self.tensorboard_logger.restore_logged_data(log_path)