logger updates

- introduced logger manager
- loggers can reload logged data from disk
Maximilian Huettenrauch 2024-03-11 10:29:17 +01:00
parent 95cbfe6cdf
commit 32cd3b4357
4 changed files with 301 additions and 61 deletions
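As a rough end-to-end sketch of what these changes enable (all paths and metric names below are placeholders, and the call sequence is inferred from the diffs that follow rather than from documented API):

# Hedged sketch: compose several loggers and later reload their data from disk.
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import TensorboardLogger
from tianshou.utils.logger.base import LoggerManager
from tianshou.utils.logger.pandas_logger import PandasLogger

log_dir = "log/my_run"  # placeholder directory
manager = LoggerManager()
manager.loggers.append(TensorboardLogger(SummaryWriter(log_dir)))
manager.loggers.append(PandasLogger(log_dir))

# Every logging call is dispatched to all contained loggers.
manager.log_train_data({"returns": 1.0}, step=1000)  # "returns" is a hypothetical metric
manager.save_data(epoch=1, env_step=1000, gradient_step=50)

# After training, the logged records can be reloaded from disk, e.g. from the
# pickle files that the PandasLogger dumps alongside its CSVs.
restored = PandasLogger.restore_logged_data(log_dir)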

View File

@@ -5,6 +5,8 @@ from typing import Literal, TypeAlias
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BaseLogger, TensorboardLogger, WandbLogger
from tianshou.utils.logger.base import LoggerManager
from tianshou.utils.logger.pandas_logger import PandasLogger
from tianshou.utils.string import ToStringMixin
TLogger: TypeAlias = BaseLogger
@@ -32,7 +34,7 @@ class LoggerFactory(ToStringMixin, ABC):
class LoggerFactoryDefault(LoggerFactory):
def __init__(
self,
logger_type: Literal["tensorboard", "wandb"] = "tensorboard",
logger_type: Literal["tensorboard", "wandb", "pandas"] = "tensorboard",
wandb_project: str | None = None,
):
if logger_type == "wandb" and wandb_project is None:
@@ -47,18 +49,21 @@ class LoggerFactoryDefault(LoggerFactory):
run_id: str | None,
config_dict: dict,
) -> TLogger:
writer = SummaryWriter(log_dir)
writer.add_text(
"args",
str(
dict(
log_dir=log_dir,
logger_type=self.logger_type,
wandb_project=self.wandb_project,
if self.logger_type in ["wandb", "tensorboard"]:
writer = SummaryWriter(log_dir)
writer.add_text(
"args",
str(
dict(
log_dir=log_dir,
logger_type=self.logger_type,
wandb_project=self.wandb_project,
),
),
),
)
)
match self.logger_type:
case "pandas":
return PandasLogger(log_dir, exclude_arrays=False)
case "wandb":
wandb_logger = WandbLogger(
save_interval=1,
@@ -73,3 +78,38 @@ class LoggerFactoryDefault(LoggerFactory):
return TensorboardLogger(writer)
case _:
raise ValueError(f"Unknown logger type '{self.logger_type}'")
class LoggerManagerFactory(LoggerFactory):
def __init__(
self,
logger_types: list[Literal["tensorboard", "wandb", "pandas"]] = ["tensorboard", "pandas"],
wandb_project: str | None = None,
):
self.logger_types = logger_types
self.wandb_project = wandb_project
self.factories = {
"pandas": LoggerFactoryDefault(logger_type="pandas"),
"wandb": LoggerFactoryDefault(logger_type="wandb", wandb_project=wandb_project),
"tensorboard": LoggerFactoryDefault(logger_type="tensorboard"),
}
def create_logger(
self,
log_dir: str,
experiment_name: str,
run_id: str | None,
config_dict: dict,
) -> TLogger:
logger_manager = LoggerManager()
for logger_type in self.logger_types:
logger_manager.loggers.append(
self.factories[logger_type].create_logger(
log_dir,
experiment_name,
run_id,
config_dict,
)
)
return logger_manager
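A minimal sketch of how the new factory might be called (the arguments are placeholders; in the high-level API they are normally supplied by the experiment machinery):

# Hedged sketch: build a combined tensorboard + pandas logger via LoggerManagerFactory.
factory = LoggerManagerFactory(logger_types=["tensorboard", "pandas"])
logger = factory.create_logger(
    log_dir="log/my_run",      # placeholder
    experiment_name="my_exp",  # placeholder
    run_id=None,
    config_dict={},
)
# The returned object is a LoggerManager whose .loggers list holds one logger
# per requested type, each created by the corresponding LoggerFactoryDefault.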

View File

@@ -37,12 +37,14 @@ class BaseLogger(ABC):
test_interval: int = 1,
update_interval: int = 1000,
info_interval: int = 1,
exclude_arrays: bool = True,
) -> None:
super().__init__()
self.train_interval = train_interval
self.test_interval = test_interval
self.update_interval = update_interval
self.info_interval = info_interval
self.exclude_arrays = exclude_arrays
self.last_log_train_step = -1
self.last_log_test_step = -1
self.last_log_update_step = -1
@@ -58,45 +60,8 @@
"""
@staticmethod
def prepare_dict_for_logging(
input_dict: dict[str, Any],
parent_key: str = "",
delimiter: str = "/",
exclude_arrays: bool = True,
) -> dict[str, VALID_LOG_VALS_TYPE]:
"""Flattens and filters a nested dictionary by recursively traversing all levels and compressing the keys.
Filtering is performed with respect to valid logging data types.
:param input_dict: The nested dictionary to be flattened and filtered.
:param parent_key: The parent key used as a prefix before the input_dict keys.
:param delimiter: The delimiter used to separate the keys.
:param exclude_arrays: Whether to exclude numpy arrays from the output.
:return: A flattened dictionary where the keys are compressed and values are filtered.
"""
result = {}
def add_to_result(
cur_dict: dict,
prefix: str = "",
) -> None:
for key, value in cur_dict.items():
if exclude_arrays and isinstance(value, np.ndarray):
continue
new_key = prefix + delimiter + key
new_key = new_key.lstrip(delimiter)
if isinstance(value, dict):
add_to_result(
value,
new_key,
)
elif isinstance(value, VALID_LOG_VALS):
result[new_key] = value
add_to_result(input_dict, prefix=parent_key)
return result
def prepare_dict_for_logging(log_data: dict) -> dict:
return log_data
def log_train_data(self, log_data: dict, step: int) -> None:
"""Use writer to log statistics generated during training.
@@ -106,8 +71,8 @@
"""
# TODO: move interval check to calling method
if step - self.last_log_train_step >= self.train_interval:
log_data = self.prepare_dict_for_logging(log_data, parent_key=DataScope.TRAIN.value)
self.write("train/env_step", step, log_data)
log_data = self.prepare_dict_for_logging(log_data)
self.write("/".join([DataScope.TRAIN.value, "env_step"]), step, log_data)
self.last_log_train_step = step
def log_test_data(self, log_data: dict, step: int) -> None:
@@ -118,8 +83,8 @@
"""
# TODO: move interval check to calling method (stupid because log_test_data is only called from function in utils.py, not from BaseTrainer)
if step - self.last_log_test_step >= self.test_interval:
log_data = self.prepare_dict_for_logging(log_data, parent_key=DataScope.TEST.value)
self.write(DataScope.TEST.value + "/env_step", step, log_data)
log_data = self.prepare_dict_for_logging(log_data)
self.write("/".join([DataScope.TEST.value, "env_step"]), step, log_data)
self.last_log_test_step = step
def log_update_data(self, log_data: dict, step: int) -> None:
@@ -130,8 +95,8 @@
"""
# TODO: move interval check to calling method
if step - self.last_log_update_step >= self.update_interval:
log_data = self.prepare_dict_for_logging(log_data, parent_key=DataScope.UPDATE.value)
self.write(DataScope.UPDATE.value + "/gradient_step", step, log_data)
log_data = self.prepare_dict_for_logging(log_data)
self.write("/".join([DataScope.UPDATE.value, "gradient_step"]), step, log_data)
self.last_log_update_step = step
def log_info_data(self, log_data: dict, step: int) -> None:
@@ -143,8 +108,8 @@
if (
step - self.last_log_info_step >= self.info_interval
): # TODO: move interval check to calling method
log_data = self.prepare_dict_for_logging(log_data, parent_key=DataScope.INFO.value)
self.write(DataScope.INFO.value + "/epoch", step, log_data)
log_data = self.prepare_dict_for_logging(log_data)
self.write("/".join([DataScope.INFO.value, "epoch"]), step, log_data)
self.last_log_info_step = step
@abstractmethod
@@ -174,6 +139,11 @@
:return: epoch, env_step, gradient_step.
"""
@staticmethod
@abstractmethod
def restore_logged_data(log_path):
"""Return the logged data from the logger."""
class LazyLogger(BaseLogger):
"""A logger that does nothing. Used as the placeholder in trainer."""
@@ -195,3 +165,58 @@ class LazyLogger(BaseLogger):
def restore_data(self) -> tuple[int, int, int]:
return 0, 0, 0
def restore_logged_data(self):
return None
class LoggerManager(BaseLogger):
"""A container of loggers that holds more than one logger."""
def __init__(self) -> None:
super().__init__()
self.loggers = []
def write(self, step_type: str, step: int, data: dict[str, VALID_LOG_VALS_TYPE]) -> None:
for logger in self.loggers:
data_copy = data.copy()
logger.write(step_type, step, data_copy)
def log_train_data(self, log_data: dict, step: int) -> None:
for logger in self.loggers:
logger.log_train_data(log_data, step)
def log_test_data(self, log_data: dict, step: int) -> None:
for logger in self.loggers:
logger.log_test_data(log_data, step)
def log_update_data(self, log_data: dict, step: int) -> None:
for logger in self.loggers:
logger.log_update_data(log_data, step)
def log_info_data(self, log_data: dict, step: int) -> None:
for logger in self.loggers:
logger.log_info_data(log_data, step)
def save_data(
self,
epoch: int,
env_step: int,
gradient_step: int,
save_checkpoint_fn: Callable[[int, int, int], str] | None = None,
) -> None:
for logger in self.loggers:
logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn)
def restore_data(self) -> tuple[int, int, int]:
for logger in self.loggers:
epoch, env_step, gradient_step = logger.restore_data()
self.last_save_step = self.last_log_test_step = epoch
self.last_log_update_step = gradient_step
self.last_log_train_step = env_step
return epoch, env_step, gradient_step
def restore_logged_data(self, log_path):
return self.loggers[0].restore_logged_data(log_path)
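One detail worth noting in the diff above: LoggerManager.write hands each contained logger its own shallow copy of the data dict, so a logger that mutates the dict while writing (the PandasLogger below, for instance, inserts the step value under the step name) does not leak that change into the next logger in the list. restore_logged_data, by contrast, simply delegates to the first contained logger.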

View File

@@ -0,0 +1,106 @@
import os
from collections import defaultdict
from functools import partial
from typing import Callable, Any
import numpy as np
import pandas as pd
from matplotlib.figure import Figure
from tianshou.utils import BaseLogger, logging
from tianshou.utils.logger.base import VALID_LOG_VALS_TYPE
class PandasLogger(BaseLogger):
def __init__(
self,
log_dir: str,
train_interval: int = 1000,
test_interval: int = 1,
update_interval: int = 1000,
info_interval: int = 1,
exclude_arrays: bool = True,
) -> None:
super().__init__(train_interval, test_interval, update_interval, info_interval, exclude_arrays)
self.log_path = log_dir
self.csv_name = os.path.join(self.log_path, "log.csv")
self.pkl_name = os.path.join(self.log_path, "log.pkl")
self.data = defaultdict(list)
self.last_save_step = -1
def write(self, step_type: str, step: int, data: dict[str, Any]) -> None:
scope, step_name = step_type.split("/")
data[step_name] = step
self.data[scope].append(data)
def save_data(self, epoch: int, env_step: int, gradient_step: int,
save_checkpoint_fn: Callable[[int, int, int], str] | None = None) -> None:
self.last_save_step = epoch
# create and dump a dataframe
for k, v in self.data.items():
df = pd.DataFrame(v)
df.to_csv(os.path.join(self.log_path, k + "_log.csv"), index_label="index")
df.to_pickle(os.path.join(self.log_path, k + "_log.pkl"))
def restore_data(self) -> tuple[int, int, int]:
for scope in ["train", "test", "update", "info"]:
try:
self.data[scope].extend(list(pd.read_pickle(os.path.join(self.log_path, scope + "_log.pkl")).T.to_dict().values()))
except FileNotFoundError:
logging.warning(f"Failed to restore {scope} data")
try: # epoch / gradient_step
epoch = self.data["info"][-1]["epoch"]
self.last_save_step = self.last_log_test_step = epoch
except (KeyError, IndexError):
epoch = 0
try:
gradient_step = self.data["update"][-1]["gradient_step"]
self.last_log_update_step = gradient_step
except (KeyError, IndexError):
gradient_step = 0
try: # offline trainer doesn't have env_step
env_step = self.data["train"][-1]["env_step"]
self.last_log_train_step = env_step
except (KeyError, IndexError):
env_step = 0
return epoch, env_step, gradient_step
@staticmethod
def restore_logged_data(log_path):
data = {}
def merge_dicts(dict_list):
result = defaultdict(list)
for d in dict_list:
for key, value in d.items():
if isinstance(value, dict):
result[key] = merge_dicts([result.get(key, {}), value])
elif isinstance(value, list):
result[key] = result.get(key, []) + value
else:
result[key].append(value)
return result
for scope in ["train", "test", "update", "info"]:
try:
dict_list = list(pd.read_pickle(os.path.join(log_path, scope + "_log.pkl")).T.to_dict().values())
data[scope] = merge_dicts(dict_list)
except FileNotFoundError:
logging.warning(f"Failed to restore {scope} data")
return data
def prepare_dict_for_logging(self, data: dict[str, Any]) -> dict[str, VALID_LOG_VALS_TYPE]:
"""Filter out matplotlib figures from the data."""
filtered_dict = data.copy()
def filter_dict(d):
for key, value in d.items():
if isinstance(value, dict):
filter_dict(value)
elif isinstance(value, Figure):
filtered_dict.pop(key)
filter_dict(data)
return filtered_dict
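As a rough illustration of the restore path defined above (the directory and the "returns" key are hypothetical; actual keys depend on what was logged):

# Hedged sketch: reload data written by a PandasLogger from its pickle files.
data = PandasLogger.restore_logged_data("log/my_run")  # placeholder path
# For each scope ("train", "test", "update", "info"), merge_dicts stacks the
# per-step records into lists, so the step values and metrics line up:
train_steps = data["train"]["env_step"]   # list of env_step values
train_returns = data["train"]["returns"]  # list of values for a hypothetical "returns" metric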

View File

@@ -1,10 +1,13 @@
from collections import defaultdict
from collections.abc import Callable
from typing import Any
import numpy as np
from matplotlib.figure import Figure
from tensorboard.backend.event_processing import event_accumulator
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils.logger.base import VALID_LOG_VALS_TYPE, BaseLogger
from tianshou.utils.logger.base import VALID_LOG_VALS_TYPE, BaseLogger, VALID_LOG_VALS
from tianshou.utils.warning import deprecation
@@ -38,9 +41,57 @@ class TensorboardLogger(BaseLogger):
self.last_save_step = -1
self.writer = writer
def write(self, step_type: str, step: int, data: dict[str, VALID_LOG_VALS_TYPE]) -> None:
@staticmethod
def prepare_dict_for_logging(
input_dict: dict[str, Any],
parent_key: str = "",
delimiter: str = "/",
exclude_arrays: bool = True,
) -> dict[str, VALID_LOG_VALS_TYPE]:
"""Flattens and filters a nested dictionary by recursively traversing all levels and compressing the keys.
Filtering is performed with respect to valid logging data types.
:param input_dict: The nested dictionary to be flattened and filtered.
:param parent_key: The parent key used as a prefix before the input_dict keys.
:param delimiter: The delimiter used to separate the keys.
:param exclude_arrays: Whether to exclude numpy arrays from the output.
:return: A flattened dictionary where the keys are compressed and values are filtered.
"""
result = {}
def add_to_result(
cur_dict: dict,
prefix: str = "",
) -> None:
for key, value in cur_dict.items():
if exclude_arrays and isinstance(value, np.ndarray):
continue
new_key = prefix + delimiter + key
new_key = new_key.lstrip(delimiter)
if isinstance(value, dict):
add_to_result(
value,
new_key,
)
elif isinstance(value, VALID_LOG_VALS):
result[new_key] = value
add_to_result(input_dict, prefix=parent_key)
return result
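For instance, under the defaults documented above, a nested statistics dict would flatten roughly like this (the keys are made up for illustration):

# Hedged illustration of the flattening (hypothetical keys).
nested = {"returns": {"mean": 1.0, "std": 0.5}, "obs": np.zeros(4)}
flat = TensorboardLogger.prepare_dict_for_logging(nested, parent_key="train")
# -> {"train/returns/mean": 1.0, "train/returns/std": 0.5}
# The array is dropped because exclude_arrays defaults to True.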
def write(self, step_type: str, step: int, data: dict[str, Any]) -> None:
scope, step_name = step_type.split("/")
for k, v in data.items():
self.writer.add_scalar(k, v, global_step=step)
scope_key = '/'.join([scope, k])
if isinstance(v, np.ndarray):
self.writer.add_histogram(scope_key, v, global_step=step, bins="auto")
elif isinstance(v, Figure):
self.writer.add_figure(scope_key, v, global_step=step)
else:
self.writer.add_scalar(scope_key, v, global_step=step)
if self.write_flush: # issue 580
self.writer.flush() # issue #482
@@ -81,6 +132,24 @@
return epoch, env_step, gradient_step
@staticmethod
def restore_logged_data(log_path):
ea = event_accumulator.EventAccumulator(log_path)
ea.Reload()
def add_to_dict(dictionary, keys, value):
current_dict = dictionary
for key in keys[:-1]:
current_dict = current_dict.setdefault(key, {})
current_dict[keys[-1]] = value
data = {}
for key in ea.scalars.Keys():
split_keys = key.split('/')
add_to_dict(data, split_keys, np.array([s.value for s in ea.scalars.Items(key)]))
return data
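Roughly, the event-file restore above rebuilds a nested dict keyed by the scalar tags; a sketch of what a caller gets back (path and tag names are hypothetical):

# Hedged sketch: read scalars back from the tensorboard event files of a run.
data = TensorboardLogger.restore_logged_data("log/my_run")  # placeholder path
# Tags are split on "/" and nested, with each scalar history as a numpy array, e.g.
# data["test"]["returns_stat"]["mean"] -> np.array([...]) of the logged values.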
class BasicLogger(TensorboardLogger):
"""BasicLogger has changed its name to TensorboardLogger in #427.