Tianshou/tianshou/utils/logger/pandas_logger.py


import os
from collections import defaultdict
from collections.abc import Callable
from typing import Any

import pandas as pd

from tianshou.utils import BaseLogger, logging
from tianshou.utils.logger.base import VALID_LOG_VALS, VALID_LOG_VALS_TYPE


class PandasLogger(BaseLogger):
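    """Logger that stores global step data in pandas DataFrames.

    For each scope ("train", "test", "update", "info"), the collected records
    are dumped by save_data() to ``<scope>_log.csv`` and ``<scope>_log.pkl``
    files under ``log_dir``.
    """
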
    def __init__(
        self,
        log_dir: str,
        train_interval: int = 1000,
        test_interval: int = 1,
        update_interval: int = 1000,
        info_interval: int = 1,
        exclude_arrays: bool = True,
    ) -> None:
        super().__init__(
            train_interval,
            test_interval,
            update_interval,
            info_interval,
            exclude_arrays,
        )
        self.log_path = log_dir
        self.csv_name = os.path.join(self.log_path, "log.csv")
        self.pkl_name = os.path.join(self.log_path, "log.pkl")
        self.data: dict[str, list] = defaultdict(list)
        self.last_save_step = -1

    @staticmethod
    def prepare_dict_for_logging(data: dict[str, Any]) -> dict[str, VALID_LOG_VALS_TYPE]:
        """Removes invalid data types from the log data."""
        filtered_dict = data.copy()

        def filter_dict(data_dict: dict[str, Any]) -> None:
            """Filter in place."""
            # iterate over a snapshot so entries can be removed during the loop
            for key, value in list(data_dict.items()):
                if isinstance(value, dict):
                    # recurse into nested dicts
                    filter_dict(value)
                elif not isinstance(value, VALID_LOG_VALS):
                    # drop entries whose values cannot be logged
                    data_dict.pop(key)

        filter_dict(filtered_dict)
        return filtered_dict
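
    # Illustrative example (values assumed, not from the original source), given
    # that str and callables are not among VALID_LOG_VALS:
    # prepare_dict_for_logging({"loss": 0.5, "extra": {"msg": "hi", "n": 3}, "cb": print})
    # returns {"loss": 0.5, "extra": {"n": 3}} -- "msg" and "cb" are dropped.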

    def write(self, step_type: str, step: int, data: dict[str, Any]) -> None:
        # step_type has the form "<scope>/<step name>", e.g. "train/env_step"
        scope, step_name = step_type.split("/")
        data[step_name] = step
        self.data[scope].append(data)

    def save_data(
        self,
        epoch: int,
        env_step: int,
        gradient_step: int,
        save_checkpoint_fn: Callable[[int, int, int], str] | None = None,
    ) -> None:
        self.last_save_step = epoch
        # create and dump a dataframe
        for k, v in self.data.items():
            df = pd.DataFrame(v)
            df.to_csv(os.path.join(self.log_path, k + "_log.csv"), index_label="index")
            df.to_pickle(os.path.join(self.log_path, k + "_log.pkl"))
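
    # For example (path assumed for illustration): with log_dir="runs/exp" and
    # logged scopes "train" and "info", save_data writes runs/exp/train_log.csv,
    # runs/exp/train_log.pkl, runs/exp/info_log.csv and runs/exp/info_log.pkl.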

    def restore_data(self) -> tuple[int, int, int]:
        """Restore logged data from disk and return (epoch, env_step, gradient_step) for resuming training."""
        for scope in ["train", "test", "update", "info"]:
            try:
                self.data[scope].extend(
                    list(
                        pd.read_pickle(os.path.join(self.log_path, scope + "_log.pkl"))
                        .T.to_dict()
                        .values(),
                    ),
                )
            except FileNotFoundError:
                logging.warning(f"Failed to restore {scope} data")

        try:  # epoch / gradient_step
            epoch = self.data["info"][-1]["epoch"]
            self.last_save_step = self.last_log_test_step = epoch
        except (KeyError, IndexError):
            epoch = 0
        try:
            gradient_step = self.data["update"][-1]["gradient_step"]
            self.last_log_update_step = gradient_step
        except (KeyError, IndexError):
            gradient_step = 0
        try:  # offline trainer doesn't have env_step
            env_step = self.data["train"][-1]["env_step"]
            self.last_log_train_step = env_step
        except (KeyError, IndexError):
            env_step = 0
        return epoch, env_step, gradient_step

    def restore_logged_data(self, log_path: str) -> dict[str, Any]:
        data = {}

        def merge_dicts(list_of_dicts: list[dict]) -> dict[str, Any]:
            result: dict[str, Any] = defaultdict(list)
            for d in list_of_dicts:
                for key, value in d.items():
                    if isinstance(value, dict):
                        result[key] = merge_dicts([result.get(key, {}), value])
                    elif isinstance(value, list):
                        result[key] = result.get(key, []) + value
                    else:
                        result[key].append(value)
            return result
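
        # Illustrative example (values assumed, not from the original source):
        # merge_dicts([{"a": 1, "b": {"c": 2}}, {"a": 3, "b": {"c": 4}}])
        # returns {"a": [1, 3], "b": {"c": [2, 4]}}, i.e. per-step records are
        # merged into one dict of time series.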
        for scope in ["train", "test", "update", "info"]:
            try:
                dict_list = list(
                    pd.read_pickle(os.path.join(log_path, scope + "_log.pkl")).T.to_dict().values(),
                )
                data[scope] = merge_dicts(dict_list)
            except FileNotFoundError:
                logging.warning(f"Failed to restore {scope} data")
        return data
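

# Illustrative usage sketch (not part of the original module; assumes tianshou
# is installed and the directory path below is writable). It shows the
# write/save/restore round trip; "train/env_step" follows the
# "<scope>/<step name>" convention expected by write().
if __name__ == "__main__":
    demo_dir = "runs/demo"  # hypothetical output directory
    os.makedirs(demo_dir, exist_ok=True)
    logger = PandasLogger(log_dir=demo_dir)
    # record two training steps; write() appends records to the "train" scope
    logger.write("train/env_step", 100, {"reward": 1.5})
    logger.write("train/env_step", 200, {"reward": 2.0})
    # persist one CSV and one pickle file per scope under demo_dir
    logger.save_data(epoch=1, env_step=200, gradient_step=0)
    # reload the logged data; scopes without a pickle file only log a warning
    print(logger.restore_logged_data(demo_dir))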