From c97aa4065ee8464bd5897bb86f1f81abd8e2cff9 Mon Sep 17 00:00:00 2001
From: rocknamx
Date: Sat, 31 Oct 2020 16:38:54 +0800
Subject: [PATCH] add singleton pattern version of summary_writer (#230)

Co-authored-by: Trinkle23897
---
 test/base/test_utils.py          | 21 +++++++++++++-
 test/discrete/test_dqn.py        |  1 +
 tianshou/env/worker/subproc.py   |  2 +-
 tianshou/policy/modelfree/ppo.py |  9 +++---
 tianshou/utils/__init__.py       |  2 ++
 tianshou/utils/log_tools.py      | 47 ++++++++++++++++++++++++++++++++
 6 files changed, 76 insertions(+), 6 deletions(-)
 create mode 100644 tianshou/utils/log_tools.py

diff --git a/test/base/test_utils.py b/test/base/test_utils.py
index 5944bfb..6057dfc 100644
--- a/test/base/test_utils.py
+++ b/test/base/test_utils.py
@@ -2,9 +2,10 @@ import torch
 import numpy as np
 
 from tianshou.utils import MovAvg
-from tianshou.exploration import GaussianNoise, OUNoise
+from tianshou.utils import SummaryWriter
 from tianshou.utils.net.common import Net
 from tianshou.utils.net.discrete import DQN
+from tianshou.exploration import GaussianNoise, OUNoise
 from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic
 
 
@@ -62,7 +63,25 @@ def test_net():
     assert list(net(data)[0].shape) == expect_output_shape
 
 
+def test_summary_writer():
+    # the first call creates the instance and sets the default key
+    writer1 = SummaryWriter.get_instance(
+        key="first", log_dir="log/test_sw/first")
+    assert writer1.log_dir == "log/test_sw/first"
+    writer2 = SummaryWriter.get_instance()
+    assert writer1 is writer2
+    # create a new instance by specifying a new key
+    writer3 = SummaryWriter.get_instance(
+        key="second", log_dir="log/test_sw/second")
+    assert writer3.log_dir == "log/test_sw/second"
+    writer4 = SummaryWriter.get_instance(key="second")
+    assert writer3 is writer4
+    assert writer1 is not writer3
+    assert writer1.log_dir != writer4.log_dir
+
+
 if __name__ == '__main__':
     test_noise()
     test_moving_average()
     test_net()
+    test_summary_writer()
diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py
index 7564c08..8ec9fa3 100644
--- a/test/discrete/test_dqn.py
+++ b/test/discrete/test_dqn.py
@@ -124,6 +124,7 @@ def test_dqn(args=get_args()):
 def test_pdqn(args=get_args()):
     args.prioritized_replay = 1
     args.gamma = .95
+    args.seed = 1
     test_dqn(args)
 
 
diff --git a/tianshou/env/worker/subproc.py b/tianshou/env/worker/subproc.py
index b578dd7..02acc21 100644
--- a/tianshou/env/worker/subproc.py
+++ b/tianshou/env/worker/subproc.py
@@ -32,7 +32,7 @@ class ShArray:
 
     def __init__(self, dtype: np.generic, shape: Tuple[int]) -> None:
         self.arr = Array(
-            _NP_TO_CT[dtype.type],  # type: ignore
+            _NP_TO_CT[dtype.type],
             int(np.prod(shape)),
         )
         self.dtype = dtype
diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py
index 60bb19e..5a04ec6 100644
--- a/tianshou/policy/modelfree/ppo.py
+++ b/tianshou/policy/modelfree/ppo.py
@@ -178,10 +178,11 @@ class PPOPolicy(PGPolicy):
                 losses.append(loss.item())
                 self.optim.zero_grad()
                 loss.backward()
-                nn.utils.clip_grad_norm_(
-                    list(self.actor.parameters())
-                    + list(self.critic.parameters()),
-                    self._max_grad_norm)
+                if self._max_grad_norm:
+                    nn.utils.clip_grad_norm_(
+                        list(self.actor.parameters())
+                        + list(self.critic.parameters()),
+                        self._max_grad_norm)
                 self.optim.step()
         return {
             "loss": losses,
diff --git a/tianshou/utils/__init__.py b/tianshou/utils/__init__.py
index e5827fd..d3a3715 100644
--- a/tianshou/utils/__init__.py
+++ b/tianshou/utils/__init__.py
@@ -1,7 +1,9 @@
 from tianshou.utils.config import tqdm_config
 from tianshou.utils.moving_average import MovAvg
+from tianshou.utils.log_tools import SummaryWriter
 
 __all__ = [
     "MovAvg",
     "tqdm_config",
+    "SummaryWriter",
 ]
diff --git a/tianshou/utils/log_tools.py b/tianshou/utils/log_tools.py
new file mode 100644
index 0000000..bbbd82e
--- /dev/null
+++ b/tianshou/utils/log_tools.py
@@ -0,0 +1,47 @@
+import threading
+from torch.utils import tensorboard
+from typing import Any, Dict, Optional
+
+
+class SummaryWriter(tensorboard.SummaryWriter):
+    """A more convenient Summary Writer (`tensorboard.SummaryWriter`).
+
+    You can get the same instance of summary writer everywhere after you
+    have created one.
+    ::
+
+        >>> writer1 = SummaryWriter.get_instance(
+                key="first", log_dir="log/test_sw/first")
+        >>> writer2 = SummaryWriter.get_instance()
+        >>> writer1 is writer2
+        True
+        >>> writer4 = SummaryWriter.get_instance(
+                key="second", log_dir="log/test_sw/second")
+        >>> writer5 = SummaryWriter.get_instance(key="second")
+        >>> writer1 is not writer4
+        True
+        >>> writer4 is writer5
+        True
+    """
+
+    _mutex_lock = threading.Lock()
+    _default_key: str
+    _instance: Optional[Dict[str, "SummaryWriter"]] = None
+
+    @classmethod
+    def get_instance(
+        cls,
+        key: Optional[str] = None,
+        *args: Any,
+        **kwargs: Any,
+    ) -> "SummaryWriter":
+        """Get instance of torch.utils.tensorboard.SummaryWriter by key."""
+        with SummaryWriter._mutex_lock:
+            if key is None:
+                key = SummaryWriter._default_key
+            if SummaryWriter._instance is None:
+                SummaryWriter._instance = {}
+                SummaryWriter._default_key = key
+            if key not in SummaryWriter._instance:
+                SummaryWriter._instance[key] = SummaryWriter(*args, **kwargs)
+        return SummaryWriter._instance[key]
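
Usage note: with the singleton in place, any module can fetch the shared
writer by key instead of passing the writer object through constructors.
Below is a minimal sketch, not part of the patch; the "run" key, the
log_dir, and the tag name are illustrative only, and add_scalar is the
stock torch.utils.tensorboard method inherited by this subclass:

    import numpy as np

    from tianshou.utils import SummaryWriter

    # create the writer once, e.g. in the training script; the first call
    # constructs the instance and records "run" as the default key
    writer = SummaryWriter.get_instance(key="run", log_dir="log/run")
    for step in range(10):
        # log a dummy scalar; add_scalar comes from tensorboard.SummaryWriter
        writer.add_scalar("train/reward", float(np.random.rand()),
                          global_step=step)

    # elsewhere, e.g. in a callback: the same key returns the same instance,
    # so metrics from different modules land in one event file
    assert SummaryWriter.get_instance(key="run") is writer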