logger updates
- introduced logger manager; loggers can now reload logged data from disk
parent 95cbfe6cdf
commit 32cd3b4357
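A rough usage sketch of the new pieces (hedged: the module path tianshou.highlevel.logger, the example directory name, and the exact call signatures are assumptions based only on the code added in this commit, not a confirmed public API). LoggerManagerFactory builds a LoggerManager that fans each logging call out to several concrete loggers, and PandasLogger.restore_logged_data reloads previously dumped data from disk:

    # Sketch only; LoggerManagerFactory and PandasLogger are the classes added in this
    # commit, but the import path and the directory name are assumptions for illustration.
    from tianshou.highlevel.logger import LoggerManagerFactory  # assumed module path
    from tianshou.utils.logger.pandas_logger import PandasLogger

    # As written, the factory eagerly builds a wandb sub-factory, which appears to require
    # a project name even when wandb is not in logger_types, so one is passed here.
    factory = LoggerManagerFactory(
        logger_types=["tensorboard", "pandas"],
        wandb_project="example-project",
    )
    logger = factory.create_logger(
        log_dir="log/example_run",
        experiment_name="example_run",
        run_id=None,
        config_dict={},
    )

    # Every call is forwarded to both the TensorboardLogger and the PandasLogger.
    # step=1000 so the default train_interval (1000) is met and the data is actually written.
    logger.log_train_data({"returns": 1.0}, step=1000)
    logger.save_data(epoch=1, env_step=1000, gradient_step=0)

    # Later, e.g. in a fresh process, reload what was logged from disk.
    restored = PandasLogger.restore_logged_data("log/example_run")
    print(sorted(restored))  # scopes that were logged, e.g. ["train"]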
@@ -5,6 +5,8 @@ from typing import Literal, TypeAlias
from torch.utils.tensorboard import SummaryWriter

from tianshou.utils import BaseLogger, TensorboardLogger, WandbLogger
from tianshou.utils.logger.base import LoggerManager
from tianshou.utils.logger.pandas_logger import PandasLogger
from tianshou.utils.string import ToStringMixin

TLogger: TypeAlias = BaseLogger
@@ -32,7 +34,7 @@ class LoggerFactory(ToStringMixin, ABC):
class LoggerFactoryDefault(LoggerFactory):
    def __init__(
        self,
        logger_type: Literal["tensorboard", "wandb"] = "tensorboard",
        logger_type: Literal["tensorboard", "wandb", "pandas"] = "tensorboard",
        wandb_project: str | None = None,
    ):
        if logger_type == "wandb" and wandb_project is None:
@@ -47,18 +49,21 @@ class LoggerFactoryDefault(LoggerFactory):
        run_id: str | None,
        config_dict: dict,
    ) -> TLogger:
        writer = SummaryWriter(log_dir)
        writer.add_text(
            "args",
            str(
                dict(
                    log_dir=log_dir,
                    logger_type=self.logger_type,
                    wandb_project=self.wandb_project,
        if self.logger_type in ["wandb", "tensorboard"]:
            writer = SummaryWriter(log_dir)
            writer.add_text(
                "args",
                str(
                    dict(
                        log_dir=log_dir,
                        logger_type=self.logger_type,
                        wandb_project=self.wandb_project,
                    ),
                ),
            ),
        )
        )
        match self.logger_type:
            case "pandas":
                return PandasLogger(log_dir, exclude_arrays=False)
            case "wandb":
                wandb_logger = WandbLogger(
                    save_interval=1,
@@ -73,3 +78,38 @@ class LoggerFactoryDefault(LoggerFactory):
                return TensorboardLogger(writer)
            case _:
                raise ValueError(f"Unknown logger type '{self.logger_type}'")


class LoggerManagerFactory(LoggerFactory):
    def __init__(
        self,
        logger_types: list[Literal["tensorboard", "wandb", "pandas"]] = ["tensorboard", "pandas"],
        wandb_project: str | None = None,
    ):
        self.logger_types = logger_types
        self.wandb_project = wandb_project

        self.factories = {
            "pandas": LoggerFactoryDefault(logger_type="pandas"),
            "wandb": LoggerFactoryDefault(logger_type="wandb", wandb_project=wandb_project),
            "tensorboard": LoggerFactoryDefault(logger_type="tensorboard"),
        }

    def create_logger(
        self,
        log_dir: str,
        experiment_name: str,
        run_id: str | None,
        config_dict: dict,
    ) -> TLogger:
        logger_manager = LoggerManager()
        for logger_type in self.logger_types:
            logger_manager.loggers.append(
                self.factories[logger_type].create_logger(
                    log_dir,
                    experiment_name,
                    run_id,
                    config_dict,
                )
            )
        return logger_manager
@@ -37,12 +37,14 @@ class BaseLogger(ABC):
        test_interval: int = 1,
        update_interval: int = 1000,
        info_interval: int = 1,
        exclude_arrays: bool = True,
    ) -> None:
        super().__init__()
        self.train_interval = train_interval
        self.test_interval = test_interval
        self.update_interval = update_interval
        self.info_interval = info_interval
        self.exclude_arrays = exclude_arrays
        self.last_log_train_step = -1
        self.last_log_test_step = -1
        self.last_log_update_step = -1
@@ -58,45 +60,8 @@ class BaseLogger(ABC):
        """

    @staticmethod
    def prepare_dict_for_logging(
        input_dict: dict[str, Any],
        parent_key: str = "",
        delimiter: str = "/",
        exclude_arrays: bool = True,
    ) -> dict[str, VALID_LOG_VALS_TYPE]:
        """Flattens and filters a nested dictionary by recursively traversing all levels and compressing the keys.

        Filtering is performed with respect to valid logging data types.

        :param input_dict: The nested dictionary to be flattened and filtered.
        :param parent_key: The parent key used as a prefix before the input_dict keys.
        :param delimiter: The delimiter used to separate the keys.
        :param exclude_arrays: Whether to exclude numpy arrays from the output.
        :return: A flattened dictionary where the keys are compressed and values are filtered.
        """
        result = {}

        def add_to_result(
            cur_dict: dict,
            prefix: str = "",
        ) -> None:
            for key, value in cur_dict.items():
                if exclude_arrays and isinstance(value, np.ndarray):
                    continue

                new_key = prefix + delimiter + key
                new_key = new_key.lstrip(delimiter)

                if isinstance(value, dict):
                    add_to_result(
                        value,
                        new_key,
                    )
                elif isinstance(value, VALID_LOG_VALS):
                    result[new_key] = value

        add_to_result(input_dict, prefix=parent_key)
        return result
    def prepare_dict_for_logging(self, log_data: dict) -> dict:
        """Return the log data unchanged; subclasses may override this to flatten or filter it."""
        return log_data

    def log_train_data(self, log_data: dict, step: int) -> None:
        """Use writer to log statistics generated during training.
@@ -106,8 +71,8 @@ class BaseLogger(ABC):
        """
        # TODO: move interval check to calling method
        if step - self.last_log_train_step >= self.train_interval:
            log_data = self.prepare_dict_for_logging(log_data, parent_key=DataScope.TRAIN.value)
            self.write("train/env_step", step, log_data)
            log_data = self.prepare_dict_for_logging(log_data)
            self.write("/".join([DataScope.TRAIN.value, "env_step"]), step, log_data)
            self.last_log_train_step = step

    def log_test_data(self, log_data: dict, step: int) -> None:
@@ -118,8 +83,8 @@ class BaseLogger(ABC):
        """
        # TODO: move interval check to calling method (stupid because log_test_data is only called from function in utils.py, not from BaseTrainer)
        if step - self.last_log_test_step >= self.test_interval:
            log_data = self.prepare_dict_for_logging(log_data, parent_key=DataScope.TEST.value)
            self.write(DataScope.TEST.value + "/env_step", step, log_data)
            log_data = self.prepare_dict_for_logging(log_data)
            self.write("/".join([DataScope.TEST.value, "env_step"]), step, log_data)
            self.last_log_test_step = step

    def log_update_data(self, log_data: dict, step: int) -> None:
@@ -130,8 +95,8 @@ class BaseLogger(ABC):
        """
        # TODO: move interval check to calling method
        if step - self.last_log_update_step >= self.update_interval:
            log_data = self.prepare_dict_for_logging(log_data, parent_key=DataScope.UPDATE.value)
            self.write(DataScope.UPDATE.value + "/gradient_step", step, log_data)
            log_data = self.prepare_dict_for_logging(log_data)
            self.write("/".join([DataScope.UPDATE.value, "gradient_step"]), step, log_data)
            self.last_log_update_step = step

    def log_info_data(self, log_data: dict, step: int) -> None:
@@ -143,8 +108,8 @@ class BaseLogger(ABC):
        if (
            step - self.last_log_info_step >= self.info_interval
        ): # TODO: move interval check to calling method
            log_data = self.prepare_dict_for_logging(log_data, parent_key=DataScope.INFO.value)
            self.write(DataScope.INFO.value + "/epoch", step, log_data)
            log_data = self.prepare_dict_for_logging(log_data)
            self.write("/".join([DataScope.INFO.value, "epoch"]), step, log_data)
            self.last_log_info_step = step

    @abstractmethod
@@ -174,6 +139,11 @@ class BaseLogger(ABC):
        :return: epoch, env_step, gradient_step.
        """

    @staticmethod
    @abstractmethod
    def restore_logged_data(log_path):
        """Return the logged data from the logger."""


class LazyLogger(BaseLogger):
    """A logger that does nothing. Used as the placeholder in trainer."""
@@ -195,3 +165,58 @@ class LazyLogger(BaseLogger):

    def restore_data(self) -> tuple[int, int, int]:
        return 0, 0, 0

    def restore_logged_data(self):
        return None


class LoggerManager(BaseLogger):
    """A container of loggers that holds more than one logger."""

    def __init__(self) -> None:
        super().__init__()
        self.loggers = []

    def write(self, step_type: str, step: int, data: dict[str, VALID_LOG_VALS_TYPE]) -> None:
        for logger in self.loggers:
            data_copy = data.copy()
            logger.write(step_type, step, data_copy)

    def log_train_data(self, log_data: dict, step: int) -> None:
        for logger in self.loggers:
            logger.log_train_data(log_data, step)

    def log_test_data(self, log_data: dict, step: int) -> None:
        for logger in self.loggers:
            logger.log_test_data(log_data, step)

    def log_update_data(self, log_data: dict, step: int) -> None:
        for logger in self.loggers:
            logger.log_update_data(log_data, step)

    def log_info_data(self, log_data: dict, step: int) -> None:
        for logger in self.loggers:
            logger.log_info_data(log_data, step)

    def save_data(
        self,
        epoch: int,
        env_step: int,
        gradient_step: int,
        save_checkpoint_fn: Callable[[int, int, int], str] | None = None,
    ) -> None:
        for logger in self.loggers:
            logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn)

    def restore_data(self) -> tuple[int, int, int]:
        for logger in self.loggers:
            epoch, env_step, gradient_step = logger.restore_data()

        self.last_save_step = self.last_log_test_step = epoch
        self.last_log_update_step = gradient_step
        self.last_log_train_step = env_step

        return epoch, env_step, gradient_step

    def restore_logged_data(self, log_path):
        return self.loggers[0].restore_logged_data(log_path)
tianshou/utils/logger/pandas_logger.py (new file, 106 lines)
@@ -0,0 +1,106 @@
import os
from collections import defaultdict
from functools import partial
from typing import Callable, Any

import numpy as np
import pandas as pd
from matplotlib.figure import Figure

from tianshou.utils import BaseLogger, logging
from tianshou.utils.logger.base import VALID_LOG_VALS_TYPE


class PandasLogger(BaseLogger):
    def __init__(
        self,
        log_dir: str,
        train_interval: int = 1000,
        test_interval: int = 1,
        update_interval: int = 1000,
        info_interval: int = 1,
        exclude_arrays: bool = True,
    ) -> None:
        super().__init__(train_interval, test_interval, update_interval, info_interval, exclude_arrays)
        self.log_path = log_dir
        self.csv_name = os.path.join(self.log_path, "log.csv")
        self.pkl_name = os.path.join(self.log_path, "log.pkl")
        self.data = defaultdict(list)
        self.last_save_step = -1

    def write(self, step_type: str, step: int, data: dict[str, Any]) -> None:
        scope, step_name = step_type.split("/")
        data[step_name] = step
        self.data[scope].append(data)

    def save_data(self, epoch: int, env_step: int, gradient_step: int,
                  save_checkpoint_fn: Callable[[int, int, int], str] | None = None) -> None:
        self.last_save_step = epoch
        # create and dump a dataframe
        for k, v in self.data.items():
            df = pd.DataFrame(v)
            df.to_csv(os.path.join(self.log_path, k + "_log.csv"), index_label="index")
            df.to_pickle(os.path.join(self.log_path, k + "_log.pkl"))

    def restore_data(self) -> tuple[int, int, int]:
        for scope in ["train", "test", "update", "info"]:
            try:
                self.data[scope].extend(list(pd.read_pickle(os.path.join(self.log_path, scope + "_log.pkl")).T.to_dict().values()))
            except FileNotFoundError:
                logging.warning(f"Failed to restore {scope} data")

        try: # epoch / gradient_step
            epoch = self.data["info"][-1]["epoch"]
            self.last_save_step = self.last_log_test_step = epoch
        except (KeyError, IndexError):
            epoch = 0
        try:
            gradient_step = self.data["update"][-1]["gradient_step"]
            self.last_log_update_step = gradient_step
        except (KeyError, IndexError):
            gradient_step = 0
        try: # offline trainer doesn't have env_step
            env_step = self.data["train"][-1]["env_step"]
            self.last_log_train_step = env_step
        except (KeyError, IndexError):
            env_step = 0

        return epoch, env_step, gradient_step

    @staticmethod
    def restore_logged_data(log_path):
        data = {}

        def merge_dicts(dict_list):
            result = defaultdict(list)
            for d in dict_list:
                for key, value in d.items():
                    if isinstance(value, dict):
                        result[key] = merge_dicts([result.get(key, {}), value])
                    elif isinstance(value, list):
                        result[key] = result.get(key, []) + value
                    else:
                        result[key].append(value)
            return result

        for scope in ["train", "test", "update", "info"]:
            try:
                dict_list = list(pd.read_pickle(os.path.join(log_path, scope + "_log.pkl")).T.to_dict().values())
                data[scope] = merge_dicts(dict_list)
            except FileNotFoundError:
                logging.warning(f"Failed to restore {scope} data")
        return data

    def prepare_dict_for_logging(self, data: dict[str, Any]) -> dict[str, VALID_LOG_VALS_TYPE]:
        """Filter out matplotlib figures from the data."""
        filtered_dict = data.copy()

        def filter_dict(d):
            for key, value in d.items():
                if isinstance(value, dict):
                    filter_dict(value)
                elif isinstance(value, Figure):
                    filtered_dict.pop(key)

        filter_dict(data)
        return filtered_dict
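For orientation, a hedged illustration (not output from a real run) of the shape produced by PandasLogger.restore_logged_data: each row appended by write() (including the step value it injects, e.g. "env_step") is merged per scope into lists by merge_dicts, so the result looks roughly like:

    # Hypothetical return value for a run that only logged training scalars;
    # keys and numbers are illustrative, not real output.
    {
        "train": {
            "returns": [1.0, 1.2],
            "env_step": [1000, 2000],
        }
    }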
@@ -1,10 +1,13 @@
from collections import defaultdict
from collections.abc import Callable
from typing import Any

import numpy as np
from matplotlib.figure import Figure
from tensorboard.backend.event_processing import event_accumulator
from torch.utils.tensorboard import SummaryWriter

from tianshou.utils.logger.base import VALID_LOG_VALS_TYPE, BaseLogger
from tianshou.utils.logger.base import VALID_LOG_VALS_TYPE, BaseLogger, VALID_LOG_VALS
from tianshou.utils.warning import deprecation

@@ -38,9 +41,57 @@ class TensorboardLogger(BaseLogger):
        self.last_save_step = -1
        self.writer = writer

    def write(self, step_type: str, step: int, data: dict[str, VALID_LOG_VALS_TYPE]) -> None:
    @staticmethod
    def prepare_dict_for_logging(
        input_dict: dict[str, Any],
        parent_key: str = "",
        delimiter: str = "/",
        exclude_arrays: bool = True,
    ) -> dict[str, VALID_LOG_VALS_TYPE]:
        """Flattens and filters a nested dictionary by recursively traversing all levels and compressing the keys.

        Filtering is performed with respect to valid logging data types.

        :param input_dict: The nested dictionary to be flattened and filtered.
        :param parent_key: The parent key used as a prefix before the input_dict keys.
        :param delimiter: The delimiter used to separate the keys.
        :param exclude_arrays: Whether to exclude numpy arrays from the output.
        :return: A flattened dictionary where the keys are compressed and values are filtered.
        """
        result = {}

        def add_to_result(
            cur_dict: dict,
            prefix: str = "",
        ) -> None:
            for key, value in cur_dict.items():
                if exclude_arrays and isinstance(value, np.ndarray):
                    continue

                new_key = prefix + delimiter + key
                new_key = new_key.lstrip(delimiter)

                if isinstance(value, dict):
                    add_to_result(
                        value,
                        new_key,
                    )
                elif isinstance(value, VALID_LOG_VALS):
                    result[new_key] = value

        add_to_result(input_dict, prefix=parent_key)
        return result

    def write(self, step_type: str, step: int, data: dict[str, Any]) -> None:
        scope, step_name = step_type.split("/")
        for k, v in data.items():
            self.writer.add_scalar(k, v, global_step=step)
            scope_key = '/'.join([scope, k])
            if isinstance(v, np.ndarray):
                self.writer.add_histogram(scope_key, v, global_step=step, bins="auto")
            elif isinstance(v, Figure):
                self.writer.add_figure(scope_key, v, global_step=step)
            else:
                self.writer.add_scalar(scope_key, v, global_step=step)
        if self.write_flush: # issue 580
            self.writer.flush() # issue #482

@@ -81,6 +132,24 @@ class TensorboardLogger(BaseLogger):

        return epoch, env_step, gradient_step

    @staticmethod
    def restore_logged_data(log_path):
        ea = event_accumulator.EventAccumulator(log_path)
        ea.Reload()

        def add_to_dict(dictionary, keys, value):
            current_dict = dictionary
            for key in keys[:-1]:
                current_dict = current_dict.setdefault(key, {})
            current_dict[keys[-1]] = value

        data = {}
        for key in ea.scalars.Keys():
            split_keys = key.split('/')
            add_to_dict(data, split_keys, np.array([s.value for s in ea.scalars.Items(key)]))

        return data


class BasicLogger(TensorboardLogger):
    """BasicLogger has changed its name to TensorboardLogger in #427.