Tianshou/tianshou/utils/logger/pandas_logger.py
Maximilian Huettenrauch 32cd3b4357 logger updates
- introduced logger manager
- loggers can reload logged data from disk
2024-03-11 10:29:17 +01:00

107 lines
4.0 KiB
Python

import os
from collections import defaultdict
from functools import partial
from typing import Callable, Any
import numpy as np
import pandas as pd
from matplotlib.figure import Figure
from tianshou.utils import BaseLogger, logging
from tianshou.utils.logger.base import VALID_LOG_VALS_TYPE
class PandasLogger(BaseLogger):
    """Logger that buffers records in memory and dumps them to disk via pandas.

    Records are grouped by *scope*, i.e. the part of ``step_type`` in front of
    the ``/`` (``"train/env_step"`` -> scope ``"train"``).  :meth:`save_data`
    writes every scope to ``<scope>_log.csv`` and ``<scope>_log.pkl`` inside
    ``log_dir``; :meth:`restore_data` and :meth:`restore_logged_data` read the
    pickle files back.
    """

    # Scopes the restore methods look for on disk.
    _SCOPES = ("train", "test", "update", "info")

    def __init__(
        self,
        log_dir: str,
        train_interval: int = 1000,
        test_interval: int = 1,
        update_interval: int = 1000,
        info_interval: int = 1,
        exclude_arrays: bool = True,
    ) -> None:
        """Create the logger.

        :param log_dir: directory the per-scope csv/pickle files are written to
            and restored from.
        :param train_interval: forwarded unchanged to ``BaseLogger.__init__``.
        :param test_interval: forwarded unchanged to ``BaseLogger.__init__``.
        :param update_interval: forwarded unchanged to ``BaseLogger.__init__``.
        :param info_interval: forwarded unchanged to ``BaseLogger.__init__``.
        :param exclude_arrays: forwarded unchanged to ``BaseLogger.__init__``.
        """
        super().__init__(
            train_interval,
            test_interval,
            update_interval,
            info_interval,
            exclude_arrays,
        )
        self.log_path = log_dir
        # NOTE(review): csv_name / pkl_name are not read by any method of this
        # class (files are written per scope); kept in case external code
        # relies on the attributes.
        self.csv_name = os.path.join(self.log_path, "log.csv")
        self.pkl_name = os.path.join(self.log_path, "log.pkl")
        # scope -> list of logged records (one dict per write() call)
        self.data: defaultdict[str, list[dict[str, Any]]] = defaultdict(list)
        self.last_save_step = -1

    def write(self, step_type: str, step: int, data: dict[str, Any]) -> None:
        """Buffer one record under the scope encoded in ``step_type``.

        :param step_type: string of the form ``"<scope>/<step_name>"``; it must
            contain exactly one ``/``, otherwise unpacking raises ``ValueError``.
        :param step: step value, stored in the record under ``<step_name>``.
        :param data: the values to log; not modified by this method.
        """
        scope, step_name = step_type.split("/")
        # Store a shallow copy so that neither the caller's dict is mutated nor
        # later changes by the caller leak into the buffered record.
        record = dict(data)
        record[step_name] = step
        self.data[scope].append(record)

    def save_data(
        self,
        epoch: int,
        env_step: int,
        gradient_step: int,
        save_checkpoint_fn: Callable[[int, int, int], str] | None = None,
    ) -> None:
        """Dump all buffered records to ``<scope>_log.csv`` / ``<scope>_log.pkl``.

        :param epoch: stored as ``last_save_step``.
        :param env_step: accepted for interface compatibility; unused here.
        :param gradient_step: accepted for interface compatibility; unused here.
        :param save_checkpoint_fn: accepted for interface compatibility; it is
            NOT called by this logger.
        """
        self.last_save_step = epoch
        # Make sure the target directory exists; an empty path means the
        # current working directory, which needs no creation.
        if self.log_path:
            os.makedirs(self.log_path, exist_ok=True)
        for scope, records in self.data.items():
            frame = pd.DataFrame(records)
            base_name = os.path.join(self.log_path, scope + "_log")
            frame.to_csv(base_name + ".csv", index_label="index")
            frame.to_pickle(base_name + ".pkl")

    @staticmethod
    def _load_records(log_path: str, scope: str) -> list[dict[str, Any]]:
        """Read ``<scope>_log.pkl`` from ``log_path`` as one dict per row.

        :raises FileNotFoundError: if the pickle file does not exist.
        """
        # NOTE: unpickling executes arbitrary code from the file. Only point
        # this at log directories written by save_data / trusted sources.
        frame = pd.read_pickle(os.path.join(log_path, scope + "_log.pkl"))
        # orient="records" keeps every column's own dtype; going through a
        # transpose would coerce mixed int/float columns to a common dtype and
        # turn integer step counters into floats.
        return frame.to_dict(orient="records")

    def _last_step(self, scope: str, key: str) -> int | None:
        """Return ``key`` of the newest record in ``scope`` as int, else ``None``."""
        try:
            return int(self.data[scope][-1][key])
        except (KeyError, IndexError, TypeError, ValueError):
            # No record, key missing, or value not convertible (e.g. NaN).
            return None

    def restore_data(self) -> tuple[int, int, int]:
        """Reload previously saved records from disk into the in-memory buffer.

        For every scope whose pickle file is missing a warning is logged and
        the scope is skipped.  The step counters ``last_save_step``,
        ``last_log_test_step``, ``last_log_update_step`` and
        ``last_log_train_step`` are updated from the newest restored record
        when the corresponding value is available.

        :return: ``(epoch, env_step, gradient_step)``; each entry is 0 when it
            could not be restored.
        """
        for scope in self._SCOPES:
            bucket = self.data[scope]
            try:
                bucket.extend(self._load_records(self.log_path, scope))
            except FileNotFoundError:
                logging.warning(f"Failed to restore {scope} data")

        epoch = self._last_step("info", "epoch")
        if epoch is not None:
            self.last_save_step = self.last_log_test_step = epoch

        gradient_step = self._last_step("update", "gradient_step")
        if gradient_step is not None:
            self.last_log_update_step = gradient_step

        # offline trainer doesn't have env_step
        env_step = self._last_step("train", "env_step")
        if env_step is not None:
            self.last_log_train_step = env_step

        return epoch or 0, env_step or 0, gradient_step or 0

    @staticmethod
    def restore_logged_data(log_path: str) -> dict[str, dict[str, Any]]:
        """Load the saved records of every scope as a dict of lists.

        :param log_path: directory containing the ``<scope>_log.pkl`` files.
        :return: mapping ``scope -> merged data``.  Scopes whose pickle file is
            missing are logged with a warning and left out of the result.
        """

        def merge_dicts(dict_list):
            """Collapse row dicts into one dict of lists.

            Scalars are appended per key, list values are concatenated and
            dict values are merged recursively.
            """
            merged = defaultdict(list)
            for row in dict_list:
                for key, value in row.items():
                    if isinstance(value, dict):
                        merged[key] = merge_dicts([merged.get(key, {}), value])
                    elif isinstance(value, list):
                        merged[key] = merged.get(key, []) + value
                    else:
                        merged[key].append(value)
            return merged

        data = {}
        for scope in PandasLogger._SCOPES:
            try:
                records = PandasLogger._load_records(log_path, scope)
            except FileNotFoundError:
                logging.warning(f"Failed to restore {scope} data")
                continue
            data[scope] = merge_dicts(records)
        return data

    def prepare_dict_for_logging(self, data: dict[str, Any]) -> dict[str, VALID_LOG_VALS_TYPE]:
        """Return a copy of ``data`` with all matplotlib figures removed.

        Nested dicts are traversed as well; ``data`` itself is left untouched.
        """

        def without_figures(mapping: dict[str, Any]) -> dict[str, Any]:
            # Rebuild each level so a figure is dropped from the level it
            # actually lives in, not from the top-level dict.
            return {
                key: without_figures(value) if isinstance(value, dict) else value
                for key, value in mapping.items()
                if not isinstance(value, Figure)
            }

        return without_figures(data)