50 lines
1.5 KiB
Python
50 lines
1.5 KiB
Python
from abc import ABC, abstractmethod
|
|
import os
|
|
from dataclasses import dataclass
|
|
from typing import Union, Optional
|
|
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
|
|
from tianshou.config import LoggerConfig
|
|
from tianshou.utils import TensorboardLogger, WandbLogger
|
|
|
|
|
|
# Union of the concrete logger backends this module can produce.
TLogger = Union[TensorboardLogger, WandbLogger]
|
|
|
|
|
|
@dataclass
class Logger:
    """Pair a logger backend with the directory path associated with it."""

    # Backend instance; one of the types covered by ``TLogger``.
    logger: TLogger

    # Filesystem path of the log directory for this logger.
    log_path: str
|
|
|
|
|
|
class LoggerFactory(ABC):
    """Abstract base class for objects that construct a ``Logger``."""

    @abstractmethod
    def create_logger(
        self,
        log_name: str,
        run_id: Optional[int],
        config_dict: dict,
    ) -> Logger:
        """Build and return a ``Logger``.

        Subclasses must override this method; the base implementation does
        nothing and returns ``None``.

        :param log_name: name under which the run is logged.
        :param run_id: optional identifier of the run.
        :param config_dict: configuration dictionary made available to the
            logger implementation.
        :return: the constructed ``Logger``.
        """
|
|
|
|
|
|
class DefaultLoggerFactory(LoggerFactory):
    """Logger factory configured by a ``LoggerConfig``.

    Depending on ``config.logger`` it builds either a ``WandbLogger`` or a
    ``TensorboardLogger``; in both cases a ``SummaryWriter`` opened on
    ``config.logdir`` is handed to the backend.
    """

    def __init__(self, config: LoggerConfig):
        """Store the configuration used by :meth:`create_logger`.

        :param config: logger configuration; this class reads its ``logger``,
            ``logdir`` and ``wandb_project`` attributes.
        """
        self.config = config

    def create_logger(self, log_name: str, run_id: Optional[int], config_dict: dict) -> Logger:
        """Create the configured logger backend and its log directory.

        :param log_name: name of the run. It is joined onto ``config.logdir``
            to form the returned ``log_path``; for the wandb backend it is
            also used as the run name, with path separators replaced by
            ``"__"``.
        :param run_id: optional run identifier, forwarded to ``WandbLogger``
            (unused for the tensorboard backend).
        :param config_dict: configuration dictionary, forwarded to
            ``WandbLogger`` as ``config`` (unused for the tensorboard backend).
        :return: a ``Logger`` holding the backend and the created ``log_path``.
        :raises ValueError: if ``config.logger`` is neither ``"wandb"`` nor
            ``"tensorboard"``.
        """
        logger_type = self.config.logger

        # Fail fast: reject an unsupported backend *before* opening the
        # SummaryWriter, so an invalid configuration does not leave an open
        # writer (and whatever it has written to logdir) behind.
        if logger_type not in ("wandb", "tensorboard"):
            raise ValueError(f"Unknown logger: {logger_type}")

        # NOTE: the writer is opened on the base logdir, not on log_path.
        writer = SummaryWriter(self.config.logdir)
        writer.add_text("args", str(self.config))

        if logger_type == "wandb":
            logger = WandbLogger(
                save_interval=1,
                # Run names must not contain path separators.
                name=log_name.replace(os.path.sep, "__"),
                run_id=run_id,
                config=config_dict,
                project=self.config.wandb_project,
            )
            logger.load(writer)
        else:
            # Only "tensorboard" can reach this branch after the check above.
            logger = TensorboardLogger(writer)

        log_path = os.path.join(self.config.logdir, log_name)
        os.makedirs(log_path, exist_ok=True)
        return Logger(logger=logger, log_path=log_path)
|