logger updates
This commit is contained in:
parent
32cd3b4357
commit
734119ec00
@ -131,7 +131,7 @@ class BaseLogger(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def restore_data(self) -> tuple[int, int, int]:
|
def restore_data(self) -> tuple[int, int, int]:
|
||||||
"""Return the metadata from existing log.
|
"""Restore internal data if present and return the metadata from existing log for continuation of training.
|
||||||
|
|
||||||
If it finds nothing or an error occurs during the recover process, it will
|
If it finds nothing or an error occurs during the recover process, it will
|
||||||
return the default parameters.
|
return the default parameters.
|
||||||
@ -142,7 +142,7 @@ class BaseLogger(ABC):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def restore_logged_data(log_path):
|
def restore_logged_data(log_path):
|
||||||
"""Return the logged data from the logger."""
|
"""Load the logged data from dist for post-processing."""
|
||||||
|
|
||||||
|
|
||||||
class LazyLogger(BaseLogger):
|
class LazyLogger(BaseLogger):
|
||||||
|
@ -71,9 +71,9 @@ class PandasLogger(BaseLogger):
|
|||||||
def restore_logged_data(log_path):
|
def restore_logged_data(log_path):
|
||||||
data = {}
|
data = {}
|
||||||
|
|
||||||
def merge_dicts(dict_list):
|
def merge_dicts(list_of_dicts):
|
||||||
result = defaultdict(list)
|
result = defaultdict(list)
|
||||||
for d in dict_list:
|
for d in list_of_dicts:
|
||||||
for key, value in d.items():
|
for key, value in d.items():
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
result[key] = merge_dicts([result.get(key, {}), value])
|
result[key] = merge_dicts([result.get(key, {}), value])
|
||||||
|
@ -84,6 +84,7 @@ class TensorboardLogger(BaseLogger):
|
|||||||
|
|
||||||
def write(self, step_type: str, step: int, data: dict[str, Any]) -> None:
|
def write(self, step_type: str, step: int, data: dict[str, Any]) -> None:
|
||||||
scope, step_name = step_type.split("/")
|
scope, step_name = step_type.split("/")
|
||||||
|
self.writer.add_scalar(step_type, step, global_step=step)
|
||||||
for k, v in data.items():
|
for k, v in data.items():
|
||||||
scope_key = '/'.join([scope, k])
|
scope_key = '/'.join([scope, k])
|
||||||
if isinstance(v, np.ndarray):
|
if isinstance(v, np.ndarray):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user