From 32cd3b43573714f8a699d2515a46b1a9de16b6de Mon Sep 17 00:00:00 2001
From: Maximilian Huettenrauch
Date: Mon, 11 Mar 2024 10:29:17 +0100
Subject: [PATCH] logger updates

- introduced logger manager
- loggers can reload logged data from disk

---
 tianshou/highlevel/logger.py           |  62 ++++++++++---
 tianshou/utils/logger/base.py          | 119 +++++++++++++++----------
 tianshou/utils/logger/pandas_logger.py | 106 ++++++++++++++++++++++
 tianshou/utils/logger/tensorboard.py   |  75 +++++++++++++++-
 4 files changed, 301 insertions(+), 61 deletions(-)
 create mode 100644 tianshou/utils/logger/pandas_logger.py

diff --git a/tianshou/highlevel/logger.py b/tianshou/highlevel/logger.py
index a4fe772..71f0b7c 100644
--- a/tianshou/highlevel/logger.py
+++ b/tianshou/highlevel/logger.py
@@ -5,6 +5,8 @@ from typing import Literal, TypeAlias
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.utils import BaseLogger, TensorboardLogger, WandbLogger
+from tianshou.utils.logger.base import LoggerManager
+from tianshou.utils.logger.pandas_logger import PandasLogger
 from tianshou.utils.string import ToStringMixin

 TLogger: TypeAlias = BaseLogger
@@ -32,7 +34,7 @@ class LoggerFactory(ToStringMixin, ABC):
 class LoggerFactoryDefault(LoggerFactory):
     def __init__(
         self,
-        logger_type: Literal["tensorboard", "wandb"] = "tensorboard",
+        logger_type: Literal["tensorboard", "wandb", "pandas"] = "tensorboard",
         wandb_project: str | None = None,
     ):
         if logger_type == "wandb" and wandb_project is None:
@@ -47,18 +49,21 @@ class LoggerFactoryDefault(LoggerFactory):
         run_id: str | None,
         config_dict: dict,
     ) -> TLogger:
-        writer = SummaryWriter(log_dir)
-        writer.add_text(
-            "args",
-            str(
-                dict(
-                    log_dir=log_dir,
-                    logger_type=self.logger_type,
-                    wandb_project=self.wandb_project,
+        # The SummaryWriter is only needed for the tensorboard-backed loggers.
+        if self.logger_type in ["wandb", "tensorboard"]:
+            writer = SummaryWriter(log_dir)
+            writer.add_text(
+                "args",
+                str(
+                    dict(
+                        log_dir=log_dir,
+                        logger_type=self.logger_type,
+                        wandb_project=self.wandb_project,
+                    ),
                 ),
-            ),
-        )
+            )
         match self.logger_type:
+            case "pandas":
+                return PandasLogger(log_dir, exclude_arrays=False)
             case "wandb":
                 wandb_logger = WandbLogger(
                     save_interval=1,
@@ -73,3 +78,38 @@ class LoggerFactoryDefault(LoggerFactory):
                 return TensorboardLogger(writer)
             case _:
                 raise ValueError(f"Unknown logger type '{self.logger_type}'")
+
+
+class LoggerManagerFactory(LoggerFactory):
+    def __init__(
+        self,
+        logger_types: list[Literal["tensorboard", "wandb", "pandas"]] | None = None,
+        wandb_project: str | None = None,
+    ):
+        # Default to None instead of a mutable default list, which would be
+        # shared between factory instances; the default combination is set here.
+        self.logger_types = logger_types if logger_types is not None else ["tensorboard", "pandas"]
+        self.wandb_project = wandb_project
+
+        self.factories = {
+            "pandas": LoggerFactoryDefault(logger_type="pandas"),
+            "wandb": LoggerFactoryDefault(logger_type="wandb", wandb_project=wandb_project),
+            "tensorboard": LoggerFactoryDefault(logger_type="tensorboard"),
+        }
+
+    def create_logger(
+        self,
+        log_dir: str,
+        experiment_name: str,
+        run_id: str | None,
+        config_dict: dict,
+    ) -> TLogger:
+        logger_manager = LoggerManager()
+        for logger_type in self.logger_types:
+            logger_manager.loggers.append(
+                self.factories[logger_type].create_logger(
+                    log_dir,
+                    experiment_name,
+                    run_id,
+                    config_dict,
+                ),
+            )
+        return logger_manager
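A minimal usage sketch of the factory added above (the directory and experiment name are hypothetical placeholders, not part of the patch):

    # Build a manager that logs to TensorBoard and to CSV/pickle at once.
    factory = LoggerManagerFactory(logger_types=["tensorboard", "pandas"])
    logger = factory.create_logger(
        log_dir="log/example_run",      # placeholder path
        experiment_name="example_run",  # placeholder name
        run_id=None,
        config_dict={},
    )
    # The returned LoggerManager fans every log_*/save_data/restore_data
    # call out to both child loggers.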
diff --git a/tianshou/utils/logger/base.py b/tianshou/utils/logger/base.py
index e6bc09a..7842860 100644
--- a/tianshou/utils/logger/base.py
+++ b/tianshou/utils/logger/base.py
@@ -37,12 +37,14 @@ class BaseLogger(ABC):
         test_interval: int = 1,
         update_interval: int = 1000,
         info_interval: int = 1,
+        exclude_arrays: bool = True,
     ) -> None:
         super().__init__()
         self.train_interval = train_interval
         self.test_interval = test_interval
         self.update_interval = update_interval
         self.info_interval = info_interval
+        self.exclude_arrays = exclude_arrays
         self.last_log_train_step = -1
         self.last_log_test_step = -1
         self.last_log_update_step = -1
@@ -58,45 +60,8 @@ class BaseLogger(ABC):
         """

     @staticmethod
-    def prepare_dict_for_logging(
-        input_dict: dict[str, Any],
-        parent_key: str = "",
-        delimiter: str = "/",
-        exclude_arrays: bool = True,
-    ) -> dict[str, VALID_LOG_VALS_TYPE]:
-        """Flattens and filters a nested dictionary by recursively traversing all levels and compressing the keys.
-
-        Filtering is performed with respect to valid logging data types.
-
-        :param input_dict: The nested dictionary to be flattened and filtered.
-        :param parent_key: The parent key used as a prefix before the input_dict keys.
-        :param delimiter: The delimiter used to separate the keys.
-        :param exclude_arrays: Whether to exclude numpy arrays from the output.
-        :return: A flattened dictionary where the keys are compressed and values are filtered.
-        """
-        result = {}
-
-        def add_to_result(
-            cur_dict: dict,
-            prefix: str = "",
-        ) -> None:
-            for key, value in cur_dict.items():
-                if exclude_arrays and isinstance(value, np.ndarray):
-                    continue
-
-                new_key = prefix + delimiter + key
-                new_key = new_key.lstrip(delimiter)
-
-                if isinstance(value, dict):
-                    add_to_result(
-                        value,
-                        new_key,
-                    )
-                elif isinstance(value, VALID_LOG_VALS):
-                    result[new_key] = value
-
-        add_to_result(input_dict, prefix=parent_key)
-        return result
+    def prepare_dict_for_logging(log_data: dict) -> dict:
+        """Preprocess a log-data dict before writing; the base implementation is the identity."""
+        return log_data

     def log_train_data(self, log_data: dict, step: int) -> None:
         """Use writer to log statistics generated during training.

         :param log_data:
         :param step: stands for the timestep the log_data being logged.
         """
         # TODO: move interval check to calling method
         if step - self.last_log_train_step >= self.train_interval:
-            log_data = self.prepare_dict_for_logging(log_data, parent_key=DataScope.TRAIN.value)
-            self.write("train/env_step", step, log_data)
+            log_data = self.prepare_dict_for_logging(log_data)
+            self.write("/".join([DataScope.TRAIN.value, "env_step"]), step, log_data)
             self.last_log_train_step = step

     def log_test_data(self, log_data: dict, step: int) -> None:
         """Use writer to log statistics generated during evaluating.

         :param log_data:
         :param step: stands for the timestep the log_data being logged.
         """
         # TODO: move interval check to calling method (stupid because log_test_data is only called from function in utils.py, not from BaseTrainer)
         if step - self.last_log_test_step >= self.test_interval:
-            log_data = self.prepare_dict_for_logging(log_data, parent_key=DataScope.TEST.value)
-            self.write(DataScope.TEST.value + "/env_step", step, log_data)
+            log_data = self.prepare_dict_for_logging(log_data)
+            self.write("/".join([DataScope.TEST.value, "env_step"]), step, log_data)
             self.last_log_test_step = step

     def log_update_data(self, log_data: dict, step: int) -> None:
         """Use writer to log statistics generated during updating.

         :param log_data:
         :param step: stands for the timestep the log_data being logged.
         """
         # TODO: move interval check to calling method
         if step - self.last_log_update_step >= self.update_interval:
-            log_data = self.prepare_dict_for_logging(log_data, parent_key=DataScope.UPDATE.value)
-            self.write(DataScope.UPDATE.value + "/gradient_step", step, log_data)
+            log_data = self.prepare_dict_for_logging(log_data)
+            self.write("/".join([DataScope.UPDATE.value, "gradient_step"]), step, log_data)
             self.last_log_update_step = step

     def log_info_data(self, log_data: dict, step: int) -> None:
         """Use writer to log global statistics.

         :param log_data:
         :param step: stands for the epoch the log_data being logged.
         """
         if (
             step - self.last_log_info_step >= self.info_interval
         ):  # TODO: move interval check to calling method
-            log_data = self.prepare_dict_for_logging(log_data, parent_key=DataScope.INFO.value)
-            self.write(DataScope.INFO.value + "/epoch", step, log_data)
+            log_data = self.prepare_dict_for_logging(log_data)
+            self.write("/".join([DataScope.INFO.value, "epoch"]), step, log_data)
             self.last_log_info_step = step

     @abstractmethod
@@ ... @@ class BaseLogger(ABC):
         :return: epoch, env_step, gradient_step.
         """

+    @staticmethod
+    @abstractmethod
+    def restore_logged_data(log_path: str) -> dict | None:
+        """Return the logged data from the logger."""
+

 class LazyLogger(BaseLogger):
     """A logger that does nothing. Used as the placeholder in trainer."""
@@ ... @@ class LazyLogger(BaseLogger):
     def restore_data(self) -> tuple[int, int, int]:
         return 0, 0, 0
+
+    def restore_logged_data(self, log_path: str) -> dict | None:
+        return None
+
+
+class LoggerManager(BaseLogger):
+    """A container of loggers that holds more than one logger."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.loggers: list[BaseLogger] = []
+
+    def write(self, step_type: str, step: int, data: dict[str, VALID_LOG_VALS_TYPE]) -> None:
+        for logger in self.loggers:
+            # Pass a copy to each logger since a logger's write() may mutate
+            # its input (e.g. PandasLogger adds the step entry).
+            data_copy = data.copy()
+            logger.write(step_type, step, data_copy)
+
+    def log_train_data(self, log_data: dict, step: int) -> None:
+        for logger in self.loggers:
+            logger.log_train_data(log_data, step)
+
+    def log_test_data(self, log_data: dict, step: int) -> None:
+        for logger in self.loggers:
+            logger.log_test_data(log_data, step)
+
+    def log_update_data(self, log_data: dict, step: int) -> None:
+        for logger in self.loggers:
+            logger.log_update_data(log_data, step)
+
+    def log_info_data(self, log_data: dict, step: int) -> None:
+        for logger in self.loggers:
+            logger.log_info_data(log_data, step)
+
+    def save_data(
+        self,
+        epoch: int,
+        env_step: int,
+        gradient_step: int,
+        save_checkpoint_fn: Callable[[int, int, int], str] | None = None,
+    ) -> None:
+        for logger in self.loggers:
+            logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn)
+
+    def restore_data(self) -> tuple[int, int, int]:
+        # All loggers are assumed to hold the same metadata: restore each of
+        # them and return the (identical) result of the last one. Defaults of
+        # zero cover the case of an empty logger list.
+        epoch, env_step, gradient_step = 0, 0, 0
+        for logger in self.loggers:
+            epoch, env_step, gradient_step = logger.restore_data()
+
+        self.last_save_step = self.last_log_test_step = epoch
+        self.last_log_update_step = gradient_step
+        self.last_log_train_step = env_step
+
+        return epoch, env_step, gradient_step
+
+    def restore_logged_data(self, log_path: str) -> dict | None:
+        # The loggers hold the same data, so restoring from the first suffices.
+        return self.loggers[0].restore_logged_data(log_path)
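To make the fan-out concrete, a sketch assuming the child loggers are constructed by hand (the directory is a placeholder):

    from torch.utils.tensorboard import SummaryWriter

    manager = LoggerManager()
    manager.loggers.append(TensorboardLogger(SummaryWriter("log/run")))
    manager.loggers.append(PandasLogger("log/run"))
    # One call, two sinks: each child applies its own interval check and
    # its own prepare_dict_for_logging before writing.
    manager.log_train_data({"returns_stat": {"mean": 1.0}}, step=1000)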
diff --git a/tianshou/utils/logger/pandas_logger.py b/tianshou/utils/logger/pandas_logger.py
new file mode 100644
index 0000000..d05d9bd
--- /dev/null
+++ b/tianshou/utils/logger/pandas_logger.py
@@ -0,0 +1,106 @@
+import os
+from collections import defaultdict
+from collections.abc import Callable
+from typing import Any
+
+import pandas as pd
+from matplotlib.figure import Figure
+
+from tianshou.utils import BaseLogger, logging
+from tianshou.utils.logger.base import VALID_LOG_VALS_TYPE
+
+
+class PandasLogger(BaseLogger):
+    def __init__(
+        self,
+        log_dir: str,
+        train_interval: int = 1000,
+        test_interval: int = 1,
+        update_interval: int = 1000,
+        info_interval: int = 1,
+        exclude_arrays: bool = True,
+    ) -> None:
+        super().__init__(train_interval, test_interval, update_interval, info_interval, exclude_arrays)
+        self.log_path = log_dir
+        self.csv_name = os.path.join(self.log_path, "log.csv")
+        self.pkl_name = os.path.join(self.log_path, "log.pkl")
+        self.data = defaultdict(list)
+        self.last_save_step = -1
+
+    def write(self, step_type: str, step: int, data: dict[str, Any]) -> None:
+        scope, step_name = step_type.split("/")
+        # Copy before adding the step entry so the caller's dict is not mutated.
+        data = data.copy()
+        data[step_name] = step
+        self.data[scope].append(data)
+
+    def save_data(
+        self,
+        epoch: int,
+        env_step: int,
+        gradient_step: int,
+        save_checkpoint_fn: Callable[[int, int, int], str] | None = None,
+    ) -> None:
+        self.last_save_step = epoch
+        # create and dump a dataframe per scope
+        for k, v in self.data.items():
+            df = pd.DataFrame(v)
+            df.to_csv(os.path.join(self.log_path, k + "_log.csv"), index_label="index")
+            df.to_pickle(os.path.join(self.log_path, k + "_log.pkl"))
+
+    def restore_data(self) -> tuple[int, int, int]:
+        for scope in ["train", "test", "update", "info"]:
+            try:
+                self.data[scope].extend(
+                    list(pd.read_pickle(os.path.join(self.log_path, scope + "_log.pkl")).T.to_dict().values()),
+                )
+            except FileNotFoundError:
+                logging.warning(f"Failed to restore {scope} data")
+
+        try:  # epoch / gradient_step
+            epoch = self.data["info"][-1]["epoch"]
+            self.last_save_step = self.last_log_test_step = epoch
+        except (KeyError, IndexError):
+            epoch = 0
+        try:
+            gradient_step = self.data["update"][-1]["gradient_step"]
+            self.last_log_update_step = gradient_step
+        except (KeyError, IndexError):
+            gradient_step = 0
+        try:  # offline trainer doesn't have env_step
+            env_step = self.data["train"][-1]["env_step"]
+            self.last_log_train_step = env_step
+        except (KeyError, IndexError):
+            env_step = 0
+
+        return epoch, env_step, gradient_step
+
+    @staticmethod
+    def restore_logged_data(log_path: str) -> dict:
+        data = {}
+
+        def merge_dicts(dict_list: list[dict]) -> dict:
+            # Merge per-step record dicts into one dict of column-like lists,
+            # recursing into nested dicts.
+            result = defaultdict(list)
+            for d in dict_list:
+                for key, value in d.items():
+                    if isinstance(value, dict):
+                        result[key] = merge_dicts([result.get(key, {}), value])
+                    elif isinstance(value, list):
+                        result[key] = result.get(key, []) + value
+                    else:
+                        result[key].append(value)
+            return result
+
+        for scope in ["train", "test", "update", "info"]:
+            try:
+                dict_list = list(pd.read_pickle(os.path.join(log_path, scope + "_log.pkl")).T.to_dict().values())
+                data[scope] = merge_dicts(dict_list)
+            except FileNotFoundError:
+                logging.warning(f"Failed to restore {scope} data")
+        return data
+
+    def prepare_dict_for_logging(self, data: dict[str, Any]) -> dict[str, VALID_LOG_VALS_TYPE]:
+        """Filter out matplotlib figures from the (possibly nested) data."""
+
+        def filter_dict(d: dict[str, Any]) -> dict[str, Any]:
+            # Rebuild the dict recursively so that figures nested below the
+            # top level are removed as well.
+            return {
+                key: filter_dict(value) if isinstance(value, dict) else value
+                for key, value in d.items()
+                if not isinstance(value, Figure)
+            }
+
+        return filter_dict(data)
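A sketch of reading the dumped data back; the scope and key names below are illustrative and depend on what was logged:

    data = PandasLogger.restore_logged_data("log/run")  # placeholder path
    # merge_dicts() turns the per-step records into column-like lists while
    # preserving nesting:
    steps = data["test"]["env_step"]              # e.g. [1000, 2000, ...]
    means = data["test"]["returns_stat"]["mean"]  # nested keys stay nested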
step_type.split("/") + data[step_name] = step + self.data[scope].append(data) + + def save_data(self, epoch: int, env_step: int, gradient_step: int, + save_checkpoint_fn: Callable[[int, int, int], str] | None = None) -> None: + self.last_save_step = epoch + # create and dump a dataframe + for k, v in self.data.items(): + df = pd.DataFrame(v) + df.to_csv(os.path.join(self.log_path, k + "_log.csv"), index_label="index") + df.to_pickle(os.path.join(self.log_path, k + "_log.pkl")) + + def restore_data(self) -> tuple[int, int, int]: + for scope in ["train", "test", "update", "info"]: + try: + self.data[scope].extend(list(pd.read_pickle(os.path.join(self.log_path, scope + "_log.pkl")).T.to_dict().values())) + except FileNotFoundError: + logging.warning(f"Failed to restore {scope} data") + + try: # epoch / gradient_step + epoch = self.data["info"][-1]["epoch"] + self.last_save_step = self.last_log_test_step = epoch + except (KeyError, IndexError): + epoch = 0 + try: + gradient_step = self.data["update"][-1]["gradient_step"] + self.last_log_update_step = gradient_step + except (KeyError, IndexError): + gradient_step = 0 + try: # offline trainer doesn't have env_step + env_step = self.data["train"][-1]["env_step"] + self.last_log_train_step = env_step + except (KeyError, IndexError): + env_step = 0 + + return epoch, env_step, gradient_step + + @staticmethod + def restore_logged_data(log_path): + data = {} + + def merge_dicts(dict_list): + result = defaultdict(list) + for d in dict_list: + for key, value in d.items(): + if isinstance(value, dict): + result[key] = merge_dicts([result.get(key, {}), value]) + elif isinstance(value, list): + result[key] = result.get(key, []) + value + else: + result[key].append(value) + return result + + for scope in ["train", "test", "update", "info"]: + try: + dict_list = list(pd.read_pickle(os.path.join(log_path, scope + "_log.pkl")).T.to_dict().values()) + data[scope] = merge_dicts(dict_list) + except FileNotFoundError: + logging.warning(f"Failed to restore {scope} data") + return data + + def prepare_dict_for_logging(self, data: dict[str, Any]) -> dict[str, VALID_LOG_VALS_TYPE]: + """Filter out matplotlib figures from the data.""" + filtered_dict = data.copy() + + def filter_dict(d): + for key, value in d.items(): + if isinstance(value, dict): + filter_dict(value) + elif isinstance(value, Figure): + filtered_dict.pop(key) + + filter_dict(data) + return filtered_dict diff --git a/tianshou/utils/logger/tensorboard.py b/tianshou/utils/logger/tensorboard.py index 2a26963..5c2a63f 100644 --- a/tianshou/utils/logger/tensorboard.py +++ b/tianshou/utils/logger/tensorboard.py @@ -1,10 +1,13 @@ +from collections import defaultdict from collections.abc import Callable from typing import Any +import numpy as np +from matplotlib.figure import Figure from tensorboard.backend.event_processing import event_accumulator from torch.utils.tensorboard import SummaryWriter -from tianshou.utils.logger.base import VALID_LOG_VALS_TYPE, BaseLogger +from tianshou.utils.logger.base import VALID_LOG_VALS_TYPE, BaseLogger, VALID_LOG_VALS from tianshou.utils.warning import deprecation @@ -38,9 +41,57 @@ class TensorboardLogger(BaseLogger): self.last_save_step = -1 self.writer = writer - def write(self, step_type: str, step: int, data: dict[str, VALID_LOG_VALS_TYPE]) -> None: + @staticmethod + def prepare_dict_for_logging( + input_dict: dict[str, Any], + parent_key: str = "", + delimiter: str = "/", + exclude_arrays: bool = True, + ) -> dict[str, VALID_LOG_VALS_TYPE]: + """Flattens and filters a 
@@ -81,6 +132,24 @@ class TensorboardLogger(BaseLogger):

         return epoch, env_step, gradient_step

+    @staticmethod
+    def restore_logged_data(log_path: str) -> dict:
+        ea = event_accumulator.EventAccumulator(log_path)
+        ea.Reload()
+
+        def add_to_dict(dictionary: dict, keys: list[str], value: np.ndarray) -> None:
+            # Walk/create the nested dicts for all but the last key segment.
+            current_dict = dictionary
+            for key in keys[:-1]:
+                current_dict = current_dict.setdefault(key, {})
+            current_dict[keys[-1]] = value
+
+        data = {}
+        for key in ea.scalars.Keys():
+            split_keys = key.split("/")
+            add_to_dict(data, split_keys, np.array([s.value for s in ea.scalars.Items(key)]))
+
+        return data
+

 class BasicLogger(TensorboardLogger):
     """BasicLogger has changed its name to TensorboardLogger in #427.
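Round-trip sketch for the TensorBoard variant (the event directory is a placeholder):

    data = TensorboardLogger.restore_logged_data("log/run")
    # Scalar tags such as "test/returns_stat/mean" come back as nested dicts
    # of numpy arrays: data["test"]["returns_stat"]["mean"] -> np.ndarray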