From 734119ec004307d534b94f28a5eb5fa5f15d9006 Mon Sep 17 00:00:00 2001 From: Maximilian Huettenrauch Date: Tue, 12 Mar 2024 11:31:41 +0100 Subject: [PATCH] logger updates --- tianshou/utils/logger/base.py | 4 ++-- tianshou/utils/logger/pandas_logger.py | 4 ++-- tianshou/utils/logger/tensorboard.py | 1 + 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tianshou/utils/logger/base.py b/tianshou/utils/logger/base.py index 7842860..afe2259 100644 --- a/tianshou/utils/logger/base.py +++ b/tianshou/utils/logger/base.py @@ -131,7 +131,7 @@ class BaseLogger(ABC): @abstractmethod def restore_data(self) -> tuple[int, int, int]: - """Return the metadata from existing log. + """Restore internal data if present and return the metadata from the existing log for continuation of training. If it finds nothing or an error occurs during the recover process, it will return the default parameters. @@ -142,7 +142,7 @@ @staticmethod @abstractmethod def restore_logged_data(log_path): - """Return the logged data from the logger.""" + """Load the logged data from disk for post-processing.""" class LazyLogger(BaseLogger): diff --git a/tianshou/utils/logger/pandas_logger.py b/tianshou/utils/logger/pandas_logger.py index d05d9bd..882cbf3 100644 --- a/tianshou/utils/logger/pandas_logger.py +++ b/tianshou/utils/logger/pandas_logger.py @@ -71,9 +71,9 @@ class PandasLogger(BaseLogger): def restore_logged_data(log_path): data = {} - def merge_dicts(dict_list): + def merge_dicts(list_of_dicts): result = defaultdict(list) - for d in dict_list: + for d in list_of_dicts: for key, value in d.items(): if isinstance(value, dict): result[key] = merge_dicts([result.get(key, {}), value]) diff --git a/tianshou/utils/logger/tensorboard.py b/tianshou/utils/logger/tensorboard.py index 5c2a63f..555a0c7 100644 --- a/tianshou/utils/logger/tensorboard.py +++ b/tianshou/utils/logger/tensorboard.py @@ -84,6 +84,7 @@ class TensorboardLogger(BaseLogger): def write(self, step_type: str, 
step: int, data: dict[str, Any]) -> None: scope, step_name = step_type.split("/") + self.writer.add_scalar(step_type, step, global_step=step) for k, v in data.items(): scope_key = '/'.join([scope, k]) if isinstance(v, np.ndarray):