import logging
import os
from collections import defaultdict
from typing import Any, Callable

import pandas as pd
from matplotlib.figure import Figure

from tianshou.utils import BaseLogger
from tianshou.utils.logger.base import VALID_LOG_VALS_TYPE


class PandasLogger(BaseLogger):
    """Logger that accumulates records in memory and persists them as
    per-scope CSV and pickle files via pandas DataFrames."""

    def __init__(
        self,
        log_dir: str,
        train_interval: int = 1000,
        test_interval: int = 1,
        update_interval: int = 1000,
        info_interval: int = 1,
        exclude_arrays: bool = True,
    ) -> None:
        super().__init__(
            train_interval, test_interval, update_interval, info_interval, exclude_arrays
        )
        self.log_path = log_dir
        self.csv_name = os.path.join(self.log_path, "log.csv")
        self.pkl_name = os.path.join(self.log_path, "log.pkl")
        self.data = defaultdict(list)
        self.last_save_step = -1

    def write(self, step_type: str, step: int, data: dict[str, Any]) -> None:
        # step_type has the form "<scope>/<step_name>", e.g. "train/env_step".
        scope, step_name = step_type.split("/")
        data[step_name] = step
        self.data[scope].append(data)

    def save_data(
        self,
        epoch: int,
        env_step: int,
        gradient_step: int,
        save_checkpoint_fn: Callable[[int, int, int], str] | None = None,
    ) -> None:
        self.last_save_step = epoch
        # Create and dump one dataframe per scope (train/test/update/info).
        for k, v in self.data.items():
            df = pd.DataFrame(v)
            df.to_csv(os.path.join(self.log_path, k + "_log.csv"), index_label="index")
            df.to_pickle(os.path.join(self.log_path, k + "_log.pkl"))

    def restore_data(self) -> tuple[int, int, int]:
        for scope in ["train", "test", "update", "info"]:
            try:
                self.data[scope].extend(
                    list(
                        pd.read_pickle(os.path.join(self.log_path, scope + "_log.pkl"))
                        .T.to_dict()
                        .values(),
                    ),
                )
            except FileNotFoundError:
                logging.warning(f"Failed to restore {scope} data")

        try:  # epoch / gradient_step
            epoch = self.data["info"][-1]["epoch"]
            self.last_save_step = self.last_log_test_step = epoch
        except (KeyError, IndexError):
            epoch = 0
        try:
            gradient_step = self.data["update"][-1]["gradient_step"]
            self.last_log_update_step = gradient_step
        except (KeyError, IndexError):
            gradient_step = 0
        try:  # the offline trainer doesn't have env_step
            env_step = self.data["train"][-1]["env_step"]
            self.last_log_train_step = env_step
        except (KeyError, IndexError):
            env_step = 0

        return epoch, env_step, gradient_step

    @staticmethod
    def restore_logged_data(log_path: str) -> dict[str, Any]:
        data = {}

        def merge_dicts(list_of_dicts: list[dict[str, Any]]) -> dict[str, Any]:
            # Merge a list of (possibly nested) record dicts into a single
            # dict mapping each key to the list of its values across records.
            result = defaultdict(list)
            for d in list_of_dicts:
                for key, value in d.items():
                    if isinstance(value, dict):
                        result[key] = merge_dicts([result.get(key, {}), value])
                    elif isinstance(value, list):
                        result[key] = result.get(key, []) + value
                    else:
                        result[key].append(value)
            return result

        for scope in ["train", "test", "update", "info"]:
            try:
                dict_list = list(
                    pd.read_pickle(os.path.join(log_path, scope + "_log.pkl"))
                    .T.to_dict()
                    .values(),
                )
                data[scope] = merge_dicts(dict_list)
            except FileNotFoundError:
                logging.warning(f"Failed to restore {scope} data")
        return data

    def prepare_dict_for_logging(self, data: dict[str, Any]) -> dict[str, VALID_LOG_VALS_TYPE]:
        """Filter out matplotlib figures from the data, including nested dicts."""

        def filter_dict(d: dict[str, Any]) -> dict[str, Any]:
            # Rebuild the dict without Figure values so nested figures are
            # dropped as well, instead of popping from the top-level dict only.
            filtered: dict[str, Any] = {}
            for key, value in d.items():
                if isinstance(value, dict):
                    filtered[key] = filter_dict(value)
                elif not isinstance(value, Figure):
                    filtered[key] = value
            return filtered

        return filter_dict(data)
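

# ----------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the logger itself). It assumes
# a tianshou version whose BaseLogger matches the constructor used above; the
# directory name and the logged values are made up for demonstration.
if __name__ == "__main__":
    log_dir = "log/pandas_logger_demo"  # hypothetical output directory
    os.makedirs(log_dir, exist_ok=True)

    logger = PandasLogger(log_dir)

    # Record a couple of training entries the same way a trainer would:
    # write() takes a "<scope>/<step_name>" identifier, the step value, and a
    # dict of scalars to log.
    logger.write("train/env_step", 1000, {"reward": 1.5})
    logger.write("train/env_step", 2000, {"reward": 2.0})
    logger.save_data(epoch=1, env_step=2000, gradient_step=100)

    # After a restart, the persisted pickles can be reloaded either into a new
    # logger (restore_data) or as plain dictionaries (restore_logged_data).
    # Warnings about the missing test/update/info files are expected here.
    restored = PandasLogger.restore_logged_data(log_dir)
    print(restored["train"])  # per-key lists, e.g. for "reward" and "env_step"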