add singleton pattern version of summary_writter (#230)

Co-authored-by: Trinkle23897 <trinkle23897@gmail.com>
This commit is contained in:
rocknamx 2020-10-31 16:38:54 +08:00 committed by GitHub
parent b364f1a26f
commit c97aa4065e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 76 additions and 6 deletions

View File

@ -2,9 +2,10 @@ import torch
import numpy as np
from tianshou.utils import MovAvg
from tianshou.exploration import GaussianNoise, OUNoise
from tianshou.utils import SummaryWriter
from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import DQN
from tianshou.exploration import GaussianNoise, OUNoise
from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic
@ -62,7 +63,25 @@ def test_net():
assert list(net(data)[0].shape) == expect_output_shape
def test_summary_writer():
# get first instance by key of `default` or your own key
writer1 = SummaryWriter.get_instance(
key="first", log_dir="log/test_sw/first")
assert writer1.log_dir == "log/test_sw/first"
writer2 = SummaryWriter.get_instance()
assert writer1 is writer2
# create new instance by specify a new key
writer3 = SummaryWriter.get_instance(
key="second", log_dir="log/test_sw/second")
assert writer3.log_dir == "log/test_sw/second"
writer4 = SummaryWriter.get_instance(key="second")
assert writer3 is writer4
assert writer1 is not writer3
assert writer1.log_dir != writer4.log_dir
if __name__ == '__main__':
test_noise()
test_moving_average()
test_net()
test_summary_writer()

View File

@ -124,6 +124,7 @@ def test_dqn(args=get_args()):
def test_pdqn(args=get_args()):
args.prioritized_replay = 1
args.gamma = .95
args.seed = 1
test_dqn(args)

View File

@ -32,7 +32,7 @@ class ShArray:
def __init__(self, dtype: np.generic, shape: Tuple[int]) -> None:
self.arr = Array(
_NP_TO_CT[dtype.type], # type: ignore
_NP_TO_CT[dtype.type],
int(np.prod(shape)),
)
self.dtype = dtype

View File

@ -178,10 +178,11 @@ class PPOPolicy(PGPolicy):
losses.append(loss.item())
self.optim.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(
list(self.actor.parameters())
+ list(self.critic.parameters()),
self._max_grad_norm)
if self._max_grad_norm:
nn.utils.clip_grad_norm_(
list(self.actor.parameters())
+ list(self.critic.parameters()),
self._max_grad_norm)
self.optim.step()
return {
"loss": losses,

View File

@ -1,7 +1,9 @@
from tianshou.utils.config import tqdm_config
from tianshou.utils.moving_average import MovAvg
from tianshou.utils.log_tools import SummaryWriter
__all__ = [
"MovAvg",
"tqdm_config",
"SummaryWriter",
]

View File

@ -0,0 +1,47 @@
import threading
from torch.utils import tensorboard
from typing import Any, Dict, Optional
class SummaryWriter(tensorboard.SummaryWriter):
"""A more convenient Summary Writer(`tensorboard.SummaryWriter`).
You can get the same instance of summary writer everywhere after you
created one.
::
>>> writer1 = SummaryWriter.get_instance(
key="first", log_dir="log/test_sw/first")
>>> writer2 = SummaryWriter.get_instance()
>>> writer1 is writer2
True
>>> writer4 = SummaryWriter.get_instance(
key="second", log_dir="log/test_sw/second")
>>> writer5 = SummaryWriter.get_instance(key="second")
>>> writer1 is not writer4
True
>>> writer4 is writer5
True
"""
_mutex_lock = threading.Lock()
_default_key: str
_instance: Optional[Dict[str, "SummaryWriter"]] = None
@classmethod
def get_instance(
cls,
key: Optional[str] = None,
*args: Any,
**kwargs: Any,
) -> "SummaryWriter":
"""Get instance of torch.utils.tensorboard.SummaryWriter by key."""
with SummaryWriter._mutex_lock:
if key is None:
key = SummaryWriter._default_key
if SummaryWriter._instance is None:
SummaryWriter._instance = {}
SummaryWriter._default_key = key
if key not in SummaryWriter._instance.keys():
SummaryWriter._instance[key] = SummaryWriter(*args, **kwargs)
return SummaryWriter._instance[key]