50 lines
1.5 KiB
Python

from abc import ABC, abstractmethod
import os
from dataclasses import dataclass
from typing import Union, Optional
from torch.utils.tensorboard import SummaryWriter
from tianshou.config import LoggerConfig
from tianshou.utils import TensorboardLogger, WandbLogger
# Union of the tianshou logger backends a factory may hand out.
TLogger = Union[TensorboardLogger, WandbLogger]


@dataclass
class Logger:
    """Bundle a tianshou logger with the filesystem path associated with its run.

    Instances are plain data holders produced by a ``LoggerFactory``.
    """

    # Backend logger instance (TensorBoard- or Weights & Biases-based).
    logger: TLogger
    # Directory path for the run; DefaultLoggerFactory sets it to
    # os.path.join(<logdir>, <log_name>) and creates it.
    log_path: str
class LoggerFactory(ABC):
    """Abstract interface for components that build a ``Logger`` for one experiment run."""

    @abstractmethod
    def create_logger(self, log_name: str, run_id: Optional[int], config_dict: dict) -> Logger:
        """Build and return the ``Logger`` for a single run.

        :param log_name: name identifying the run's logs.
        :param run_id: optional run identifier; may be ``None``.
        :param config_dict: experiment configuration that implementations may attach
            to the logger.
        :return: the created ``Logger``.
        """
class DefaultLoggerFactory(LoggerFactory):
    """Build a TensorBoard- or Weights & Biases-backed ``Logger`` from a ``LoggerConfig``.

    The backend is chosen by ``config.logger`` (``"tensorboard"`` or ``"wandb"``), and
    each run gets its own directory ``os.path.join(config.logdir, log_name)``.
    """

    def __init__(self, config: LoggerConfig):
        """Store the logger settings.

        :param config: logger configuration; this class reads its ``logdir``,
            ``logger`` and ``wandb_project`` attributes.
        """
        self.config = config

    def create_logger(self, log_name: str, run_id: Optional[int], config_dict: dict) -> Logger:
        """Create the logger and run directory for one run.

        :param log_name: relative name of the run. It is joined onto ``config.logdir``
            to form the run directory. With path separators replaced by ``"__"``, it is
            also used as the W&B run name.
        :param run_id: forwarded to ``WandbLogger`` as ``run_id``; unused for TensorBoard.
        :param config_dict: forwarded to ``WandbLogger`` as ``config``; unused for
            TensorBoard.
        :return: a ``Logger`` wrapping the backend logger and the run directory.
        :raises ValueError: if ``config.logger`` is neither ``"wandb"`` nor ``"tensorboard"``.
        """
        logger_type = self.config.logger
        # Validate before touching the filesystem so a bad config does not leave an
        # open SummaryWriter or a stray event file behind.
        if logger_type not in ("wandb", "tensorboard"):
            raise ValueError(f"Unknown logger: {logger_type}")

        log_path = os.path.join(self.config.logdir, log_name)
        os.makedirs(log_path, exist_ok=True)

        # Write TensorBoard events into the per-run directory. Using the shared logdir
        # root would put the event files of every run into a single TensorBoard run.
        writer = SummaryWriter(log_path)
        try:
            writer.add_text("args", str(self.config))
            logger: TLogger
            if logger_type == "wandb":
                logger = WandbLogger(
                    save_interval=1,
                    name=log_name.replace(os.path.sep, "__"),
                    run_id=run_id,
                    config=config_dict,
                    project=self.config.wandb_project,
                )
                logger.load(writer)
            else:
                logger = TensorboardLogger(writer)
        except BaseException:
            # Do not leak the writer if backend construction fails; re-raise unchanged.
            writer.close()
            raise
        return Logger(logger=logger, log_path=log_path)