small fix
This commit is contained in:
parent
d9a612a997
commit
a7898b15b8
@ -58,9 +58,15 @@ class BaseLogger(ABC):
|
|||||||
:param data: the data to write with format ``{key: value}``.
|
:param data: the data to write with format ``{key: value}``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@abstractmethod
|
||||||
def prepare_dict_for_logging(log_data: dict) -> dict[str, VALID_LOG_VALS_TYPE]:
|
def prepare_dict_for_logging(self, log_data: dict) -> dict[str, VALID_LOG_VALS_TYPE]:
|
||||||
return log_data
|
"""Prepare the dict for logging by filtering out invalid data types.
|
||||||
|
|
||||||
|
If necessary, reformulate the dict to be compatible with the writer.
|
||||||
|
|
||||||
|
:param log_data: the dict to be prepared for logging.
|
||||||
|
:return: the prepared dict.
|
||||||
|
"""
|
||||||
|
|
||||||
def log_train_data(self, log_data: dict, step: int) -> None:
|
def log_train_data(self, log_data: dict, step: int) -> None:
|
||||||
"""Use writer to log statistics generated during training.
|
"""Use writer to log statistics generated during training.
|
||||||
@ -152,6 +158,12 @@ class LazyLogger(BaseLogger):
|
|||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
def prepare_dict_for_logging(
|
||||||
|
self,
|
||||||
|
data: dict[str, VALID_LOG_VALS_TYPE],
|
||||||
|
) -> dict[str, VALID_LOG_VALS_TYPE]:
|
||||||
|
return data
|
||||||
|
|
||||||
def write(self, step_type: str, step: int, data: dict[str, VALID_LOG_VALS_TYPE]) -> None:
|
def write(self, step_type: str, step: int, data: dict[str, VALID_LOG_VALS_TYPE]) -> None:
|
||||||
"""The LazyLogger writes nothing."""
|
"""The LazyLogger writes nothing."""
|
||||||
|
|
||||||
|
@ -32,8 +32,7 @@ class PandasLogger(BaseLogger):
|
|||||||
self.data: dict[str, list] = defaultdict(list)
|
self.data: dict[str, list] = defaultdict(list)
|
||||||
self.last_save_step = -1
|
self.last_save_step = -1
|
||||||
|
|
||||||
@staticmethod
|
def prepare_dict_for_logging(self, data: dict[str, Any]) -> dict[str, VALID_LOG_VALS_TYPE]:
|
||||||
def prepare_dict_for_logging(data: dict[str, Any]) -> dict[str, VALID_LOG_VALS_TYPE]:
|
|
||||||
"""Removes invalid data types from the log data."""
|
"""Removes invalid data types from the log data."""
|
||||||
filtered_dict = data.copy()
|
filtered_dict = data.copy()
|
||||||
|
|
||||||
|
@ -40,8 +40,8 @@ class TensorboardLogger(BaseLogger):
|
|||||||
self.last_save_step = -1
|
self.last_save_step = -1
|
||||||
self.writer = writer
|
self.writer = writer
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def prepare_dict_for_logging(
|
def prepare_dict_for_logging(
|
||||||
|
self,
|
||||||
input_dict: dict[str, Any],
|
input_dict: dict[str, Any],
|
||||||
parent_key: str = "",
|
parent_key: str = "",
|
||||||
delimiter: str = "/",
|
delimiter: str = "/",
|
||||||
|
@ -82,6 +82,10 @@ class WandbLogger(BaseLogger):
|
|||||||
self.wandb_run._label(repo="tianshou") # type: ignore
|
self.wandb_run._label(repo="tianshou") # type: ignore
|
||||||
self.tensorboard_logger: TensorboardLogger | None = None
|
self.tensorboard_logger: TensorboardLogger | None = None
|
||||||
|
|
||||||
|
def prepare_dict_for_logging(self, log_data: dict) -> dict[str, VALID_LOG_VALS_TYPE]:
|
||||||
|
assert self.tensorboard_logger is not None
|
||||||
|
return self.tensorboard_logger.prepare_dict_for_logging(log_data)
|
||||||
|
|
||||||
def load(self, writer: SummaryWriter) -> None:
|
def load(self, writer: SummaryWriter) -> None:
|
||||||
self.writer = writer
|
self.writer = writer
|
||||||
self.tensorboard_logger = TensorboardLogger(
|
self.tensorboard_logger = TensorboardLogger(
|
||||||
@ -158,5 +162,6 @@ class WandbLogger(BaseLogger):
|
|||||||
return epoch, env_step, gradient_step
|
return epoch, env_step, gradient_step
|
||||||
|
|
||||||
def restore_logged_data(self, log_path: str) -> dict:
|
def restore_logged_data(self, log_path: str) -> dict:
|
||||||
|
# TODO: add support for restoring logged data from W&B
|
||||||
assert self.tensorboard_logger is not None
|
assert self.tensorboard_logger is not None
|
||||||
return self.tensorboard_logger.restore_logged_data(log_path)
|
return self.tensorboard_logger.restore_logged_data(log_path)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user