Tianshou/tianshou/utils/log_tools.py
rocknamx c97aa4065e
add singleton pattern version of summary_writter (#230)
Co-authored-by: Trinkle23897 <trinkle23897@gmail.com>
2020-10-31 16:38:54 +08:00

48 lines
1.5 KiB
Python

import threading
from torch.utils import tensorboard
from typing import Any, Dict, Optional
class SummaryWriter(tensorboard.SummaryWriter):
    """A more convenient summary writer built on ``tensorboard.SummaryWriter``.

    Writers are kept in a registry shared by the whole process and keyed by a
    string, so you can get the same writer instance anywhere after you have
    created one. The key used by the very first :meth:`get_instance` call
    becomes the default key, which is used whenever ``key`` is omitted.
    ::

        >>> writer1 = SummaryWriter.get_instance(
        ...     key="first", log_dir="log/test_sw/first")
        >>> writer2 = SummaryWriter.get_instance()
        >>> writer1 is writer2
        True
        >>> writer3 = SummaryWriter.get_instance(
        ...     key="second", log_dir="log/test_sw/second")
        >>> writer4 = SummaryWriter.get_instance(key="second")
        >>> writer1 is not writer3
        True
        >>> writer3 is writer4
        True
    """

    # Guards every read/write of the registry below.
    _mutex_lock = threading.Lock()
    # Key registered by the first get_instance call; used when key is omitted.
    # Initialized to None so an early key-less call no longer hits
    # AttributeError (it was previously only annotated, never assigned).
    _default_key: Optional[str] = None
    # Registry of writers by key, created lazily on the first get_instance call.
    _instance: Optional[Dict[str, "SummaryWriter"]] = None
    # Key used when get_instance() is called without a key before any
    # default key has been registered.
    _FALLBACK_KEY = "default"

    @classmethod
    def get_instance(
        cls,
        key: Optional[str] = None,
        *args: Any,
        **kwargs: Any,
    ) -> "SummaryWriter":
        """Get instance of torch.utils.tensorboard.SummaryWriter by key.

        :param key: name of the writer to fetch or create. When ``None``, the
            default key is used: the key of the first writer ever requested,
            or ``"default"`` if nothing has been requested yet.
        :param args: positional arguments forwarded to the ``SummaryWriter``
            constructor when a new writer has to be created.
        :param kwargs: keyword arguments forwarded to the constructor likewise.

        :return: the writer registered under ``key``. If that writer already
            exists, ``args`` and ``kwargs`` are ignored.
        """
        with SummaryWriter._mutex_lock:
            if key is None:
                key = SummaryWriter._default_key
                if key is None:
                    # No writer requested yet: use a well-known key instead of
                    # failing, so the first key-less call can create a writer.
                    key = SummaryWriter._FALLBACK_KEY
            if SummaryWriter._instance is None:
                # First call ever: create the registry and remember its key
                # as the default for later key-less calls.
                SummaryWriter._instance = {}
                SummaryWriter._default_key = key
            if key not in SummaryWriter._instance:
                SummaryWriter._instance[key] = SummaryWriter(*args, **kwargs)
            return SummaryWriter._instance[key]