logger updates

This commit is contained in:
Maximilian Huettenrauch 2024-03-12 11:31:41 +01:00
parent 32cd3b4357
commit 734119ec00
3 changed files with 5 additions and 4 deletions

View File

@ -131,7 +131,7 @@ class BaseLogger(ABC):
@abstractmethod @abstractmethod
def restore_data(self) -> tuple[int, int, int]: def restore_data(self) -> tuple[int, int, int]:
"""Return the metadata from existing log. """Restore internal data if present and return the metadata from existing log for continuation of training.
If it finds nothing or an error occurs during the recover process, it will If it finds nothing or an error occurs during the recover process, it will
return the default parameters. return the default parameters.
@ -142,7 +142,7 @@ class BaseLogger(ABC):
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def restore_logged_data(log_path): def restore_logged_data(log_path):
"""Return the logged data from the logger.""" """Load the logged data from dist for post-processing."""
class LazyLogger(BaseLogger): class LazyLogger(BaseLogger):

View File

@ -71,9 +71,9 @@ class PandasLogger(BaseLogger):
def restore_logged_data(log_path): def restore_logged_data(log_path):
data = {} data = {}
def merge_dicts(dict_list): def merge_dicts(list_of_dicts):
result = defaultdict(list) result = defaultdict(list)
for d in dict_list: for d in list_of_dicts:
for key, value in d.items(): for key, value in d.items():
if isinstance(value, dict): if isinstance(value, dict):
result[key] = merge_dicts([result.get(key, {}), value]) result[key] = merge_dicts([result.get(key, {}), value])

View File

@ -84,6 +84,7 @@ class TensorboardLogger(BaseLogger):
def write(self, step_type: str, step: int, data: dict[str, Any]) -> None: def write(self, step_type: str, step: int, data: dict[str, Any]) -> None:
scope, step_name = step_type.split("/") scope, step_name = step_type.split("/")
self.writer.add_scalar(step_type, step, global_step=step)
for k, v in data.items(): for k, v in data.items():
scope_key = '/'.join([scope, k]) scope_key = '/'.join([scope, k])
if isinstance(v, np.ndarray): if isinstance(v, np.ndarray):