66 lines
2.2 KiB
Python
Raw Normal View History

import os
2023-09-20 09:29:34 +02:00
from abc import ABC, abstractmethod
from typing import Literal, TypeAlias
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import TensorboardLogger, WandbLogger
from tianshou.utils.string import ToStringMixin
# Union of the concrete tianshou logger types that a LoggerFactory's create_logger returns.
TLogger: TypeAlias = TensorboardLogger | WandbLogger
class LoggerFactory(ToStringMixin, ABC):
    """Abstract factory producing the logger used to record an experiment run."""

    @abstractmethod
    def create_logger(
        self,
        log_dir: str,
        experiment_name: str,
        run_id: str | None,
        config_dict: dict,
    ) -> TLogger:
        """Create the logger for a single experiment run.

        :param log_dir: path of the directory in which log data is to be stored
        :param experiment_name: the name of the job; it may contain ``os.path.sep``
        :param run_id: a unique name (possibly ``None``) which, depending on the logging
            framework, may be used to identify the logger
        :param config_dict: a dictionary with data that is to be logged
        :return: the newly created logger
        """
class DefaultLoggerFactory(LoggerFactory):
    """Logger factory creating either a TensorBoard logger or a Weights & Biases logger.

    In both cases a ``SummaryWriter`` rooted at the requested log directory is created
    and the factory settings are written to it under the ``"args"`` text tag. For the
    wandb variant, that writer is additionally passed to ``WandbLogger.load``.
    """

    def __init__(
        self,
        logger_type: Literal["tensorboard", "wandb"] = "tensorboard",
        wandb_project: str | None = None,
    ):
        """:param logger_type: the logging backend to use, ``"tensorboard"`` or ``"wandb"``
        :param wandb_project: the wandb project name; required if ``logger_type`` is ``"wandb"``
        :raises ValueError: if ``logger_type`` is ``"wandb"`` but no ``wandb_project`` is given
        """
        if logger_type == "wandb" and wandb_project is None:
            raise ValueError("Must provide 'wandb_project'")
        self.logger_type = logger_type
        self.wandb_project = wandb_project

    def create_logger(
        self,
        log_dir: str,
        experiment_name: str,
        run_id: str | None,
        config_dict: dict,
    ) -> TLogger:
        """Create the logger configured by this factory.

        :param log_dir: path to the directory in which log data is to be stored
        :param experiment_name: the name of the job, which may contain os.path.sep
        :param run_id: a unique name, which, depending on the logging framework, may be used
            to identify the logger (passed on to the wandb logger only)
        :param config_dict: a dictionary with data that is to be logged (passed on to the
            wandb logger only)
        :return: a ``WandbLogger`` if ``logger_type`` is ``"wandb"``, else a ``TensorboardLogger``
        :raises ValueError: if ``logger_type`` is not ``"tensorboard"`` or ``"wandb"``; this is
            checked before any writer is created
        """
        # Validate before constructing the SummaryWriter so that an unsupported type does
        # not leave a half-initialised writer behind.
        if self.logger_type not in ("tensorboard", "wandb"):
            raise ValueError(f"Unknown logger type '{self.logger_type}'")

        writer = SummaryWriter(log_dir)
        try:
            # Record the factory settings alongside the run's logged data.
            writer.add_text(
                "args",
                str(
                    dict(
                        log_dir=log_dir,
                        logger_type=self.logger_type,
                        wandb_project=self.wandb_project,
                    ),
                ),
            )
            logger: TLogger
            if self.logger_type == "wandb":
                logger = WandbLogger(
                    save_interval=1,
                    # Flatten path separators so a nested experiment name yields a single run name.
                    name=experiment_name.replace(os.path.sep, "__"),
                    run_id=run_id,
                    config=config_dict,
                    project=self.wandb_project,
                )
                logger.load(writer)
            else:
                logger = TensorboardLogger(writer)
        except BaseException:
            # Logger construction failed: close the writer instead of leaking it, then re-raise.
            writer.close()
            raise
        return logger