logger updates

- introduced a LoggerManager that forwards logging calls to multiple loggers at once (usage sketch below)
- loggers can now reload previously logged data from disk via restore_logged_data
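
A minimal usage sketch of the two features (not part of the commit; the directory name, the logged dict, and the step values are made-up examples):

    from torch.utils.tensorboard import SummaryWriter

    from tianshou.utils import TensorboardLogger
    from tianshou.utils.logger.base import LoggerManager
    from tianshou.utils.logger.pandas_logger import PandasLogger

    log_dir = "log/example_run"  # hypothetical directory

    # fan out every logging call to both a TensorboardLogger and a PandasLogger
    manager = LoggerManager()
    manager.loggers.append(TensorboardLogger(SummaryWriter(log_dir)))
    manager.loggers.append(PandasLogger(log_dir, exclude_arrays=False))

    manager.log_train_data({"returns": 1.0}, step=1000)
    manager.save_data(epoch=1, env_step=1000, gradient_step=0)  # PandasLogger dumps per-scope csv/pkl files

    # reload what was logged from disk
    data = PandasLogger.restore_logged_data(log_dir)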
Maximilian Huettenrauch 2024-03-11 10:29:17 +01:00
parent 95cbfe6cdf
commit 32cd3b4357
4 changed files with 301 additions and 61 deletions

View File

@@ -5,6 +5,8 @@ from typing import Literal, TypeAlias
from torch.utils.tensorboard import SummaryWriter

from tianshou.utils import BaseLogger, TensorboardLogger, WandbLogger
from tianshou.utils.logger.base import LoggerManager
from tianshou.utils.logger.pandas_logger import PandasLogger
from tianshou.utils.string import ToStringMixin

TLogger: TypeAlias = BaseLogger
@@ -32,7 +34,7 @@ class LoggerFactoryDefault(LoggerFactory):
class LoggerFactoryDefault(LoggerFactory):
    def __init__(
        self,
        logger_type: Literal["tensorboard", "wandb", "pandas"] = "tensorboard",
        wandb_project: str | None = None,
    ):
        if logger_type == "wandb" and wandb_project is None:
@@ -47,18 +49,21 @@ class LoggerFactoryDefault(LoggerFactory):
        run_id: str | None,
        config_dict: dict,
    ) -> TLogger:
        if self.logger_type in ["wandb", "tensorboard"]:
            writer = SummaryWriter(log_dir)
            writer.add_text(
                "args",
                str(
                    dict(
                        log_dir=log_dir,
                        logger_type=self.logger_type,
                        wandb_project=self.wandb_project,
                    ),
                ),
            )
        match self.logger_type:
            case "pandas":
                return PandasLogger(log_dir, exclude_arrays=False)
            case "wandb":
                wandb_logger = WandbLogger(
                    save_interval=1,
@@ -73,3 +78,38 @@ class LoggerFactoryDefault(LoggerFactory):
                return TensorboardLogger(writer)
            case _:
                raise ValueError(f"Unknown logger type '{self.logger_type}'")


class LoggerManagerFactory(LoggerFactory):
    def __init__(
        self,
        logger_types: list[Literal["tensorboard", "wandb", "pandas"]] = ["tensorboard", "pandas"],
        wandb_project: str | None = None,
    ):
        self.logger_types = logger_types
        self.wandb_project = wandb_project
        self.factories = {
            "pandas": LoggerFactoryDefault(logger_type="pandas"),
            "wandb": LoggerFactoryDefault(logger_type="wandb", wandb_project=wandb_project),
            "tensorboard": LoggerFactoryDefault(logger_type="tensorboard"),
        }

    def create_logger(
        self,
        log_dir: str,
        experiment_name: str,
        run_id: str | None,
        config_dict: dict,
    ) -> TLogger:
        logger_manager = LoggerManager()
        for logger_type in self.logger_types:
            logger_manager.loggers.append(
                self.factories[logger_type].create_logger(
                    log_dir,
                    experiment_name,
                    run_id,
                    config_dict,
                ),
            )
        return logger_manager
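
The new factory is used like the existing one. A hedged sketch (assuming LoggerManagerFactory is imported from the module shown above; the directory and experiment name are made up):

    factory = LoggerManagerFactory(logger_types=["tensorboard", "pandas"])
    logger = factory.create_logger(
        log_dir="log/my_experiment",      # hypothetical path
        experiment_name="my_experiment",  # hypothetical name
        run_id=None,
        config_dict={},
    )
    # the returned object is a LoggerManager; one call is forwarded to both underlying loggers
    logger.log_update_data({"loss": 0.5}, step=2000)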

View File

@@ -37,12 +37,14 @@ class BaseLogger(ABC):
        test_interval: int = 1,
        update_interval: int = 1000,
        info_interval: int = 1,
        exclude_arrays: bool = True,
    ) -> None:
        super().__init__()
        self.train_interval = train_interval
        self.test_interval = test_interval
        self.update_interval = update_interval
        self.info_interval = info_interval
        self.exclude_arrays = exclude_arrays
        self.last_log_train_step = -1
        self.last_log_test_step = -1
        self.last_log_update_step = -1
@@ -58,45 +60,8 @@ class BaseLogger(ABC):
        """

    @staticmethod
    def prepare_dict_for_logging(log_data: dict) -> dict:
        return log_data
    def log_train_data(self, log_data: dict, step: int) -> None:
        """Use writer to log statistics generated during training.
@@ -106,8 +71,8 @@ class BaseLogger(ABC):
        """
        # TODO: move interval check to calling method
        if step - self.last_log_train_step >= self.train_interval:
            log_data = self.prepare_dict_for_logging(log_data)
            self.write("/".join([DataScope.TRAIN.value, "env_step"]), step, log_data)
            self.last_log_train_step = step

    def log_test_data(self, log_data: dict, step: int) -> None:
@@ -118,8 +83,8 @@ class BaseLogger(ABC):
        """
        # TODO: move interval check to calling method (stupid because log_test_data is only called from function in utils.py, not from BaseTrainer)
        if step - self.last_log_test_step >= self.test_interval:
            log_data = self.prepare_dict_for_logging(log_data)
            self.write("/".join([DataScope.TEST.value, "env_step"]), step, log_data)
            self.last_log_test_step = step

    def log_update_data(self, log_data: dict, step: int) -> None:
@@ -130,8 +95,8 @@ class BaseLogger(ABC):
        """
        # TODO: move interval check to calling method
        if step - self.last_log_update_step >= self.update_interval:
            log_data = self.prepare_dict_for_logging(log_data)
            self.write("/".join([DataScope.UPDATE.value, "gradient_step"]), step, log_data)
            self.last_log_update_step = step

    def log_info_data(self, log_data: dict, step: int) -> None:
@@ -143,8 +108,8 @@ class BaseLogger(ABC):
        if (
            step - self.last_log_info_step >= self.info_interval
        ):  # TODO: move interval check to calling method
            log_data = self.prepare_dict_for_logging(log_data)
            self.write("/".join([DataScope.INFO.value, "epoch"]), step, log_data)
            self.last_log_info_step = step
@abstractmethod @abstractmethod
@ -174,6 +139,11 @@ class BaseLogger(ABC):
:return: epoch, env_step, gradient_step. :return: epoch, env_step, gradient_step.
""" """
@staticmethod
@abstractmethod
def restore_logged_data(log_path):
"""Return the logged data from the logger."""
class LazyLogger(BaseLogger): class LazyLogger(BaseLogger):
"""A logger that does nothing. Used as the placeholder in trainer.""" """A logger that does nothing. Used as the placeholder in trainer."""
@ -195,3 +165,58 @@ class LazyLogger(BaseLogger):
def restore_data(self) -> tuple[int, int, int]: def restore_data(self) -> tuple[int, int, int]:
return 0, 0, 0 return 0, 0, 0
def restore_logged_data(self):
return None


class LoggerManager(BaseLogger):
    """A container of loggers that holds more than one logger."""

    def __init__(self) -> None:
        super().__init__()
        self.loggers = []

    def write(self, step_type: str, step: int, data: dict[str, VALID_LOG_VALS_TYPE]) -> None:
        for logger in self.loggers:
            data_copy = data.copy()
            logger.write(step_type, step, data_copy)

    def log_train_data(self, log_data: dict, step: int) -> None:
        for logger in self.loggers:
            logger.log_train_data(log_data, step)

    def log_test_data(self, log_data: dict, step: int) -> None:
        for logger in self.loggers:
            logger.log_test_data(log_data, step)

    def log_update_data(self, log_data: dict, step: int) -> None:
        for logger in self.loggers:
            logger.log_update_data(log_data, step)

    def log_info_data(self, log_data: dict, step: int) -> None:
        for logger in self.loggers:
            logger.log_info_data(log_data, step)

    def save_data(
        self,
        epoch: int,
        env_step: int,
        gradient_step: int,
        save_checkpoint_fn: Callable[[int, int, int], str] | None = None,
    ) -> None:
        for logger in self.loggers:
            logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn)

    def restore_data(self) -> tuple[int, int, int]:
        for logger in self.loggers:
            epoch, env_step, gradient_step = logger.restore_data()

        self.last_save_step = self.last_log_test_step = epoch
        self.last_log_update_step = gradient_step
        self.last_log_train_step = env_step

        return epoch, env_step, gradient_step

    def restore_logged_data(self, log_path):
        return self.loggers[0].restore_logged_data(log_path)
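
Since restore_logged_data is now part of the abstract interface, custom loggers have to provide it. A minimal illustrative sketch (PrintLogger is not part of the commit; it only exists here to show which methods a subclass must implement):

    from tianshou.utils.logger.base import BaseLogger

    class PrintLogger(BaseLogger):
        def write(self, step_type: str, step: int, data: dict) -> None:
            print(step_type, step, data)

        def save_data(self, epoch, env_step, gradient_step, save_checkpoint_fn=None) -> None:
            pass  # nothing is persisted

        def restore_data(self) -> tuple[int, int, int]:
            return 0, 0, 0

        def restore_logged_data(self, log_path):
            return {}  # nothing was persisted, so there is nothing to reload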

View File

@@ -0,0 +1,106 @@
import os
from collections import defaultdict
from functools import partial
from typing import Callable, Any

import numpy as np
import pandas as pd
from matplotlib.figure import Figure

from tianshou.utils import BaseLogger, logging
from tianshou.utils.logger.base import VALID_LOG_VALS_TYPE


class PandasLogger(BaseLogger):
    def __init__(
        self,
        log_dir: str,
        train_interval: int = 1000,
        test_interval: int = 1,
        update_interval: int = 1000,
        info_interval: int = 1,
        exclude_arrays: bool = True,
    ) -> None:
        super().__init__(train_interval, test_interval, update_interval, info_interval, exclude_arrays)
        self.log_path = log_dir
        self.csv_name = os.path.join(self.log_path, "log.csv")
        self.pkl_name = os.path.join(self.log_path, "log.pkl")
        self.data = defaultdict(list)
        self.last_save_step = -1

    def write(self, step_type: str, step: int, data: dict[str, Any]) -> None:
        scope, step_name = step_type.split("/")
        data[step_name] = step
        self.data[scope].append(data)
    def save_data(
        self,
        epoch: int,
        env_step: int,
        gradient_step: int,
        save_checkpoint_fn: Callable[[int, int, int], str] | None = None,
    ) -> None:
        self.last_save_step = epoch
        # create and dump a dataframe
        for k, v in self.data.items():
            df = pd.DataFrame(v)
            df.to_csv(os.path.join(self.log_path, k + "_log.csv"), index_label="index")
            df.to_pickle(os.path.join(self.log_path, k + "_log.pkl"))

    def restore_data(self) -> tuple[int, int, int]:
        for scope in ["train", "test", "update", "info"]:
            try:
                self.data[scope].extend(
                    list(pd.read_pickle(os.path.join(self.log_path, scope + "_log.pkl")).T.to_dict().values()),
                )
            except FileNotFoundError:
                logging.warning(f"Failed to restore {scope} data")

        try:  # epoch / gradient_step
            epoch = self.data["info"][-1]["epoch"]
            self.last_save_step = self.last_log_test_step = epoch
        except (KeyError, IndexError):
            epoch = 0
        try:
            gradient_step = self.data["update"][-1]["gradient_step"]
            self.last_log_update_step = gradient_step
        except (KeyError, IndexError):
            gradient_step = 0
        try:  # offline trainer doesn't have env_step
            env_step = self.data["train"][-1]["env_step"]
            self.last_log_train_step = env_step
        except (KeyError, IndexError):
            env_step = 0

        return epoch, env_step, gradient_step
    @staticmethod
    def restore_logged_data(log_path):
        data = {}

        def merge_dicts(dict_list):
            result = defaultdict(list)
            for d in dict_list:
                for key, value in d.items():
                    if isinstance(value, dict):
                        result[key] = merge_dicts([result.get(key, {}), value])
                    elif isinstance(value, list):
                        result[key] = result.get(key, []) + value
                    else:
                        result[key].append(value)
            return result

        for scope in ["train", "test", "update", "info"]:
            try:
                dict_list = list(pd.read_pickle(os.path.join(log_path, scope + "_log.pkl")).T.to_dict().values())
                data[scope] = merge_dicts(dict_list)
            except FileNotFoundError:
                logging.warning(f"Failed to restore {scope} data")
        return data
    def prepare_dict_for_logging(self, data: dict[str, Any]) -> dict[str, VALID_LOG_VALS_TYPE]:
        """Filter out matplotlib figures from the (possibly nested) data."""

        def filter_dict(d: dict[str, Any]) -> dict[str, Any]:
            # rebuild the dict without Figure values, recursing into nested dicts
            return {
                key: filter_dict(value) if isinstance(value, dict) else value
                for key, value in d.items()
                if not isinstance(value, Figure)
            }

        return filter_dict(data)
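
Reading the data back is a static call on the class. A hedged sketch (the path is made up; the keys inside each scope depend on what was actually logged, though "env_step" is always present for scopes written via write(), which stores the step under the step name):

    from tianshou.utils.logger.pandas_logger import PandasLogger

    data = PandasLogger.restore_logged_data("log/my_experiment")
    # data maps scope -> merged dict of lists, e.g. data["test"]["env_step"] is the list of
    # environment steps at which test statistics were recorded (if test data was logged)
    print(sorted(data.keys()))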

View File

@@ -1,10 +1,13 @@
from collections import defaultdict
from collections.abc import Callable
from typing import Any

import numpy as np
from matplotlib.figure import Figure
from tensorboard.backend.event_processing import event_accumulator
from torch.utils.tensorboard import SummaryWriter

from tianshou.utils.logger.base import VALID_LOG_VALS_TYPE, BaseLogger, VALID_LOG_VALS
from tianshou.utils.warning import deprecation
@@ -38,9 +41,57 @@ class TensorboardLogger(BaseLogger):
        self.last_save_step = -1
        self.writer = writer

    @staticmethod
    def prepare_dict_for_logging(
        input_dict: dict[str, Any],
        parent_key: str = "",
        delimiter: str = "/",
        exclude_arrays: bool = True,
    ) -> dict[str, VALID_LOG_VALS_TYPE]:
        """Flattens and filters a nested dictionary by recursively traversing all levels and compressing the keys.

        Filtering is performed with respect to valid logging data types.

        :param input_dict: The nested dictionary to be flattened and filtered.
        :param parent_key: The parent key used as a prefix before the input_dict keys.
        :param delimiter: The delimiter used to separate the keys.
        :param exclude_arrays: Whether to exclude numpy arrays from the output.
        :return: A flattened dictionary where the keys are compressed and values are filtered.
        """
        result = {}

        def add_to_result(
            cur_dict: dict,
            prefix: str = "",
        ) -> None:
            for key, value in cur_dict.items():
                if exclude_arrays and isinstance(value, np.ndarray):
                    continue
                new_key = prefix + delimiter + key
                new_key = new_key.lstrip(delimiter)
                if isinstance(value, dict):
                    add_to_result(
                        value,
                        new_key,
                    )
                elif isinstance(value, VALID_LOG_VALS):
                    result[new_key] = value

        add_to_result(input_dict, prefix=parent_key)
        return result
    def write(self, step_type: str, step: int, data: dict[str, Any]) -> None:
        scope, step_name = step_type.split("/")
        for k, v in data.items():
            scope_key = "/".join([scope, k])
            if isinstance(v, np.ndarray):
                self.writer.add_histogram(scope_key, v, global_step=step, bins="auto")
            elif isinstance(v, Figure):
                self.writer.add_figure(scope_key, v, global_step=step)
            else:
                self.writer.add_scalar(scope_key, v, global_step=step)
        if self.write_flush:  # issue 580
            self.writer.flush()  # issue #482
@@ -81,6 +132,24 @@ class TensorboardLogger(BaseLogger):
        return epoch, env_step, gradient_step
    @staticmethod
    def restore_logged_data(log_path):
        ea = event_accumulator.EventAccumulator(log_path)
        ea.Reload()

        def add_to_dict(dictionary, keys, value):
            current_dict = dictionary
            for key in keys[:-1]:
                current_dict = current_dict.setdefault(key, {})
            current_dict[keys[-1]] = value

        data = {}
        for key in ea.scalars.Keys():
            split_keys = key.split("/")
            add_to_dict(data, split_keys, np.array([s.value for s in ea.scalars.Items(key)]))
        return data


class BasicLogger(TensorboardLogger):
    """BasicLogger has changed its name to TensorboardLogger in #427.