logger updates
- introduced a logger manager; loggers can now reload logged data from disk
parent 95cbfe6cdf
commit 32cd3b4357
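The headline capability is that a finished run's logged data can be read back from disk. A minimal sketch of that, assuming a run has already written its files under a placeholder `log_dir` (the path is illustrative, the static `restore_logged_data` hooks are the ones added in this commit):

    from tianshou.utils import TensorboardLogger
    from tianshou.utils.logger.pandas_logger import PandasLogger

    log_dir = "log/experiment"  # placeholder path assumed to contain a finished run

    # Both return nested dicts keyed by scope ("train", "test", "update", "info").
    tb_data = TensorboardLogger.restore_logged_data(log_dir)  # from tensorboard event files
    pd_data = PandasLogger.restore_logged_data(log_dir)       # from the pickled per-scope dataframes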
@@ -5,6 +5,8 @@ from typing import Literal, TypeAlias
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.utils import BaseLogger, TensorboardLogger, WandbLogger
+from tianshou.utils.logger.base import LoggerManager
+from tianshou.utils.logger.pandas_logger import PandasLogger
 from tianshou.utils.string import ToStringMixin

 TLogger: TypeAlias = BaseLogger
@@ -32,7 +34,7 @@ class LoggerFactory(ToStringMixin, ABC):
 class LoggerFactoryDefault(LoggerFactory):
     def __init__(
         self,
-        logger_type: Literal["tensorboard", "wandb"] = "tensorboard",
+        logger_type: Literal["tensorboard", "wandb", "pandas"] = "tensorboard",
         wandb_project: str | None = None,
     ):
         if logger_type == "wandb" and wandb_project is None:
@@ -47,18 +49,21 @@ class LoggerFactoryDefault(LoggerFactory):
         run_id: str | None,
         config_dict: dict,
     ) -> TLogger:
-        writer = SummaryWriter(log_dir)
-        writer.add_text(
-            "args",
-            str(
-                dict(
-                    log_dir=log_dir,
-                    logger_type=self.logger_type,
-                    wandb_project=self.wandb_project,
-                ),
-            ),
-        )
+        if self.logger_type in ["wandb", "tensorboard"]:
+            writer = SummaryWriter(log_dir)
+            writer.add_text(
+                "args",
+                str(
+                    dict(
+                        log_dir=log_dir,
+                        logger_type=self.logger_type,
+                        wandb_project=self.wandb_project,
+                    ),
+                ),
+            )
         match self.logger_type:
+            case "pandas":
+                return PandasLogger(log_dir, exclude_arrays=False)
             case "wandb":
                 wandb_logger = WandbLogger(
                     save_interval=1,
@@ -73,3 +78,38 @@ class LoggerFactoryDefault(LoggerFactory):
                 return TensorboardLogger(writer)
             case _:
                 raise ValueError(f"Unknown logger type '{self.logger_type}'")
+
+
+class LoggerManagerFactory(LoggerFactory):
+    def __init__(
+        self,
+        logger_types: list[Literal["tensorboard", "wandb", "pandas"]] = ["tensorboard", "pandas"],
+        wandb_project: str | None = None,
+    ):
+        self.logger_types = logger_types
+        self.wandb_project = wandb_project
+
+        self.factories = {
+            "pandas": LoggerFactoryDefault(logger_type="pandas"),
+            "wandb": LoggerFactoryDefault(logger_type="wandb", wandb_project=wandb_project),
+            "tensorboard": LoggerFactoryDefault(logger_type="tensorboard"),
+        }
+
+    def create_logger(
+        self,
+        log_dir: str,
+        experiment_name: str,
+        run_id: str | None,
+        config_dict: dict,
+    ) -> TLogger:
+        logger_manager = LoggerManager()
+        for logger_type in self.logger_types:
+            logger_manager.loggers.append(
+                self.factories[logger_type].create_logger(
+                    log_dir,
+                    experiment_name,
+                    run_id,
+                    config_dict,
+                )
+            )
+        return logger_manager
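A plausible use of the new factory, sketched with placeholder arguments (the directory, experiment name, and config dict below are illustrative, not taken from the commit); it returns one LoggerManager that forwards to a TensorboardLogger and a PandasLogger writing into the same directory:

    factory = LoggerManagerFactory(logger_types=["tensorboard", "pandas"])
    logger = factory.create_logger(
        log_dir="log/experiment",      # placeholder directory
        experiment_name="experiment",  # placeholder name
        run_id=None,
        config_dict={},
    )
    # `logger` is a LoggerManager; every log_*/save_data/restore_data call fans out to both backends.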
@@ -37,12 +37,14 @@ class BaseLogger(ABC):
         test_interval: int = 1,
         update_interval: int = 1000,
         info_interval: int = 1,
+        exclude_arrays: bool = True,
     ) -> None:
         super().__init__()
         self.train_interval = train_interval
         self.test_interval = test_interval
         self.update_interval = update_interval
         self.info_interval = info_interval
+        self.exclude_arrays = exclude_arrays
         self.last_log_train_step = -1
         self.last_log_test_step = -1
         self.last_log_update_step = -1
@@ -58,45 +60,8 @@ class BaseLogger(ABC):
         """

     @staticmethod
-    def prepare_dict_for_logging(
-        input_dict: dict[str, Any],
-        parent_key: str = "",
-        delimiter: str = "/",
-        exclude_arrays: bool = True,
-    ) -> dict[str, VALID_LOG_VALS_TYPE]:
-        """Flattens and filters a nested dictionary by recursively traversing all levels and compressing the keys.
-
-        Filtering is performed with respect to valid logging data types.
-
-        :param input_dict: The nested dictionary to be flattened and filtered.
-        :param parent_key: The parent key used as a prefix before the input_dict keys.
-        :param delimiter: The delimiter used to separate the keys.
-        :param exclude_arrays: Whether to exclude numpy arrays from the output.
-        :return: A flattened dictionary where the keys are compressed and values are filtered.
-        """
-        result = {}
-
-        def add_to_result(
-            cur_dict: dict,
-            prefix: str = "",
-        ) -> None:
-            for key, value in cur_dict.items():
-                if exclude_arrays and isinstance(value, np.ndarray):
-                    continue
-
-                new_key = prefix + delimiter + key
-                new_key = new_key.lstrip(delimiter)
-
-                if isinstance(value, dict):
-                    add_to_result(
-                        value,
-                        new_key,
-                    )
-                elif isinstance(value, VALID_LOG_VALS):
-                    result[new_key] = value
-
-        add_to_result(input_dict, prefix=parent_key)
-        return result
+    def prepare_dict_for_logging(log_data: dict) -> dict:
+        return log_data

     def log_train_data(self, log_data: dict, step: int) -> None:
         """Use writer to log statistics generated during training.
@@ -106,8 +71,8 @@ class BaseLogger(ABC):
         """
         # TODO: move interval check to calling method
         if step - self.last_log_train_step >= self.train_interval:
-            log_data = self.prepare_dict_for_logging(log_data, parent_key=DataScope.TRAIN.value)
-            self.write("train/env_step", step, log_data)
+            log_data = self.prepare_dict_for_logging(log_data)
+            self.write("/".join([DataScope.TRAIN.value, "env_step"]), step, log_data)
             self.last_log_train_step = step

     def log_test_data(self, log_data: dict, step: int) -> None:
@@ -118,8 +83,8 @@ class BaseLogger(ABC):
         """
         # TODO: move interval check to calling method (stupid because log_test_data is only called from function in utils.py, not from BaseTrainer)
         if step - self.last_log_test_step >= self.test_interval:
-            log_data = self.prepare_dict_for_logging(log_data, parent_key=DataScope.TEST.value)
-            self.write(DataScope.TEST.value + "/env_step", step, log_data)
+            log_data = self.prepare_dict_for_logging(log_data)
+            self.write("/".join([DataScope.TEST.value, "env_step"]), step, log_data)
             self.last_log_test_step = step

     def log_update_data(self, log_data: dict, step: int) -> None:
@@ -130,8 +95,8 @@ class BaseLogger(ABC):
         """
         # TODO: move interval check to calling method
         if step - self.last_log_update_step >= self.update_interval:
-            log_data = self.prepare_dict_for_logging(log_data, parent_key=DataScope.UPDATE.value)
-            self.write(DataScope.UPDATE.value + "/gradient_step", step, log_data)
+            log_data = self.prepare_dict_for_logging(log_data)
+            self.write("/".join([DataScope.UPDATE.value, "gradient_step"]), step, log_data)
             self.last_log_update_step = step

     def log_info_data(self, log_data: dict, step: int) -> None:
@@ -143,8 +108,8 @@ class BaseLogger(ABC):
         if (
             step - self.last_log_info_step >= self.info_interval
         ):  # TODO: move interval check to calling method
-            log_data = self.prepare_dict_for_logging(log_data, parent_key=DataScope.INFO.value)
-            self.write(DataScope.INFO.value + "/epoch", step, log_data)
+            log_data = self.prepare_dict_for_logging(log_data)
+            self.write("/".join([DataScope.INFO.value, "epoch"]), step, log_data)
             self.last_log_info_step = step

     @abstractmethod
@@ -174,6 +139,11 @@ class BaseLogger(ABC):
         :return: epoch, env_step, gradient_step.
         """

+    @staticmethod
+    @abstractmethod
+    def restore_logged_data(log_path):
+        """Return the logged data from the logger."""
+

 class LazyLogger(BaseLogger):
     """A logger that does nothing. Used as the placeholder in trainer."""
@@ -195,3 +165,58 @@ class LazyLogger(BaseLogger):

     def restore_data(self) -> tuple[int, int, int]:
         return 0, 0, 0
+
+    def restore_logged_data(self):
+        return None
+
+
+class LoggerManager(BaseLogger):
+    """A container of loggers that holds more than one logger."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.loggers = []
+
+    def write(self, step_type: str, step: int, data: dict[str, VALID_LOG_VALS_TYPE]) -> None:
+        for logger in self.loggers:
+            data_copy = data.copy()
+            logger.write(step_type, step, data_copy)
+
+    def log_train_data(self, log_data: dict, step: int) -> None:
+        for logger in self.loggers:
+            logger.log_train_data(log_data, step)
+
+    def log_test_data(self, log_data: dict, step: int) -> None:
+        for logger in self.loggers:
+            logger.log_test_data(log_data, step)
+
+    def log_update_data(self, log_data: dict, step: int) -> None:
+        for logger in self.loggers:
+            logger.log_update_data(log_data, step)
+
+    def log_info_data(self, log_data: dict, step: int) -> None:
+        for logger in self.loggers:
+            logger.log_info_data(log_data, step)
+
+    def save_data(
+        self,
+        epoch: int,
+        env_step: int,
+        gradient_step: int,
+        save_checkpoint_fn: Callable[[int, int, int], str] | None = None,
+    ) -> None:
+        for logger in self.loggers:
+            logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn)
+
+    def restore_data(self) -> tuple[int, int, int]:
+        for logger in self.loggers:
+            epoch, env_step, gradient_step = logger.restore_data()
+
+        self.last_save_step = self.last_log_test_step = epoch
+        self.last_log_update_step = gradient_step
+        self.last_log_train_step = env_step
+
+        return epoch, env_step, gradient_step
+
+    def restore_logged_data(self, log_path):
+        return self.loggers[0].restore_logged_data(log_path)
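The manager simply fans every call out to its child loggers, so it can also be assembled by hand. A minimal sketch, with a placeholder log directory:

    from torch.utils.tensorboard import SummaryWriter

    from tianshou.utils import TensorboardLogger
    from tianshou.utils.logger.base import LoggerManager
    from tianshou.utils.logger.pandas_logger import PandasLogger

    manager = LoggerManager()
    manager.loggers.append(TensorboardLogger(SummaryWriter("log/experiment")))  # placeholder dir
    manager.loggers.append(PandasLogger("log/experiment", exclude_arrays=False))

    # One call, two backends: each child logger receives its own copy of the data.
    manager.log_train_data({"returns": 1.0}, step=1000)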
tianshou/utils/logger/pandas_logger.py (new file, 106 lines)
@@ -0,0 +1,106 @@
+import os
+from collections import defaultdict
+from functools import partial
+from typing import Callable, Any
+
+import numpy as np
+import pandas as pd
+from matplotlib.figure import Figure
+
+from tianshou.utils import BaseLogger, logging
+from tianshou.utils.logger.base import VALID_LOG_VALS_TYPE
+
+
+class PandasLogger(BaseLogger):
+    def __init__(
+        self,
+        log_dir: str,
+        train_interval: int = 1000,
+        test_interval: int = 1,
+        update_interval: int = 1000,
+        info_interval: int = 1,
+        exclude_arrays: bool = True,
+    ) -> None:
+        super().__init__(train_interval, test_interval, update_interval, info_interval, exclude_arrays)
+        self.log_path = log_dir
+        self.csv_name = os.path.join(self.log_path, "log.csv")
+        self.pkl_name = os.path.join(self.log_path, "log.pkl")
+        self.data = defaultdict(list)
+        self.last_save_step = -1
+
+    def write(self, step_type: str, step: int, data: dict[str, Any]) -> None:
+        scope, step_name = step_type.split("/")
+        data[step_name] = step
+        self.data[scope].append(data)
+
+    def save_data(self, epoch: int, env_step: int, gradient_step: int,
+                  save_checkpoint_fn: Callable[[int, int, int], str] | None = None) -> None:
+        self.last_save_step = epoch
+        # create and dump a dataframe
+        for k, v in self.data.items():
+            df = pd.DataFrame(v)
+            df.to_csv(os.path.join(self.log_path, k + "_log.csv"), index_label="index")
+            df.to_pickle(os.path.join(self.log_path, k + "_log.pkl"))
+
+    def restore_data(self) -> tuple[int, int, int]:
+        for scope in ["train", "test", "update", "info"]:
+            try:
+                self.data[scope].extend(list(pd.read_pickle(os.path.join(self.log_path, scope + "_log.pkl")).T.to_dict().values()))
+            except FileNotFoundError:
+                logging.warning(f"Failed to restore {scope} data")
+
+        try:  # epoch / gradient_step
+            epoch = self.data["info"][-1]["epoch"]
+            self.last_save_step = self.last_log_test_step = epoch
+        except (KeyError, IndexError):
+            epoch = 0
+        try:
+            gradient_step = self.data["update"][-1]["gradient_step"]
+            self.last_log_update_step = gradient_step
+        except (KeyError, IndexError):
+            gradient_step = 0
+        try:  # offline trainer doesn't have env_step
+            env_step = self.data["train"][-1]["env_step"]
+            self.last_log_train_step = env_step
+        except (KeyError, IndexError):
+            env_step = 0
+
+        return epoch, env_step, gradient_step
+
+    @staticmethod
+    def restore_logged_data(log_path):
+        data = {}
+
+        def merge_dicts(dict_list):
+            result = defaultdict(list)
+            for d in dict_list:
+                for key, value in d.items():
+                    if isinstance(value, dict):
+                        result[key] = merge_dicts([result.get(key, {}), value])
+                    elif isinstance(value, list):
+                        result[key] = result.get(key, []) + value
+                    else:
+                        result[key].append(value)
+            return result
+
+        for scope in ["train", "test", "update", "info"]:
+            try:
+                dict_list = list(pd.read_pickle(os.path.join(log_path, scope + "_log.pkl")).T.to_dict().values())
+                data[scope] = merge_dicts(dict_list)
+            except FileNotFoundError:
+                logging.warning(f"Failed to restore {scope} data")
+        return data
+
+    def prepare_dict_for_logging(self, data: dict[str, Any]) -> dict[str, VALID_LOG_VALS_TYPE]:
+        """Filter out matplotlib figures from the data."""
+        filtered_dict = data.copy()
+
+        def filter_dict(d):
+            for key, value in d.items():
+                if isinstance(value, dict):
+                    filter_dict(value)
+                elif isinstance(value, Figure):
+                    filtered_dict.pop(key)
+
+        filter_dict(data)
+        return filtered_dict
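A small round trip through the new PandasLogger, with placeholder values, showing how the scope-prefixed step types end up as per-scope CSV/pickle files and come back via restore_logged_data (the directory and metric name are illustrative):

    import tempfile

    log_dir = tempfile.mkdtemp()  # placeholder directory for the sketch

    logger = PandasLogger(log_dir, exclude_arrays=False)
    logger.write("train/env_step", 1000, {"returns": 1.5})
    logger.write("test/env_step", 1000, {"returns": 2.0})
    logger.save_data(epoch=1, env_step=1000, gradient_step=10)  # writes train_log.pkl, test_log.pkl, ...

    data = PandasLogger.restore_logged_data(log_dir)
    # data["train"]["returns"] -> [1.5], data["train"]["env_step"] -> [1000]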
@@ -1,10 +1,13 @@
+from collections import defaultdict
 from collections.abc import Callable
 from typing import Any

+import numpy as np
+from matplotlib.figure import Figure
 from tensorboard.backend.event_processing import event_accumulator
 from torch.utils.tensorboard import SummaryWriter

-from tianshou.utils.logger.base import VALID_LOG_VALS_TYPE, BaseLogger
+from tianshou.utils.logger.base import VALID_LOG_VALS_TYPE, BaseLogger, VALID_LOG_VALS
 from tianshou.utils.warning import deprecation
@@ -38,9 +41,57 @@ class TensorboardLogger(BaseLogger):
         self.last_save_step = -1
         self.writer = writer

-    def write(self, step_type: str, step: int, data: dict[str, VALID_LOG_VALS_TYPE]) -> None:
+    @staticmethod
+    def prepare_dict_for_logging(
+        input_dict: dict[str, Any],
+        parent_key: str = "",
+        delimiter: str = "/",
+        exclude_arrays: bool = True,
+    ) -> dict[str, VALID_LOG_VALS_TYPE]:
+        """Flattens and filters a nested dictionary by recursively traversing all levels and compressing the keys.
+
+        Filtering is performed with respect to valid logging data types.
+
+        :param input_dict: The nested dictionary to be flattened and filtered.
+        :param parent_key: The parent key used as a prefix before the input_dict keys.
+        :param delimiter: The delimiter used to separate the keys.
+        :param exclude_arrays: Whether to exclude numpy arrays from the output.
+        :return: A flattened dictionary where the keys are compressed and values are filtered.
+        """
+        result = {}
+
+        def add_to_result(
+            cur_dict: dict,
+            prefix: str = "",
+        ) -> None:
+            for key, value in cur_dict.items():
+                if exclude_arrays and isinstance(value, np.ndarray):
+                    continue
+
+                new_key = prefix + delimiter + key
+                new_key = new_key.lstrip(delimiter)
+
+                if isinstance(value, dict):
+                    add_to_result(
+                        value,
+                        new_key,
+                    )
+                elif isinstance(value, VALID_LOG_VALS):
+                    result[new_key] = value
+
+        add_to_result(input_dict, prefix=parent_key)
+        return result
+
+    def write(self, step_type: str, step: int, data: dict[str, Any]) -> None:
+        scope, step_name = step_type.split("/")
         for k, v in data.items():
-            self.writer.add_scalar(k, v, global_step=step)
+            scope_key = '/'.join([scope, k])
+            if isinstance(v, np.ndarray):
+                self.writer.add_histogram(scope_key, v, global_step=step, bins="auto")
+            elif isinstance(v, Figure):
+                self.writer.add_figure(scope_key, v, global_step=step)
+            else:
+                self.writer.add_scalar(scope_key, v, global_step=step)
         if self.write_flush:  # issue 580
             self.writer.flush()  # issue #482
@@ -81,6 +132,24 @@ class TensorboardLogger(BaseLogger):

         return epoch, env_step, gradient_step

+    @staticmethod
+    def restore_logged_data(log_path):
+        ea = event_accumulator.EventAccumulator(log_path)
+        ea.Reload()
+
+        def add_to_dict(dictionary, keys, value):
+            current_dict = dictionary
+            for key in keys[:-1]:
+                current_dict = current_dict.setdefault(key, {})
+            current_dict[keys[-1]] = value
+
+        data = {}
+        for key in ea.scalars.Keys():
+            split_keys = key.split('/')
+            add_to_dict(data, split_keys, np.array([s.value for s in ea.scalars.Items(key)]))
+
+        return data
+

 class BasicLogger(TensorboardLogger):
     """BasicLogger has changed its name to TensorboardLogger in #427.
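The TensorBoard restore path goes through the event files rather than pickles. A short sketch of reading a run back (the path and tag name are placeholders; the call itself is the static method added above):

    data = TensorboardLogger.restore_logged_data("log/experiment")  # placeholder path
    # A scalar logged under "train/returns" comes back as data["train"]["returns"],
    # a numpy array of the recorded values in step order.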