add singleton pattern version of summary_writter (#230)
Co-authored-by: Trinkle23897 <trinkle23897@gmail.com>
This commit is contained in:
parent
b364f1a26f
commit
c97aa4065e
@ -2,9 +2,10 @@ import torch
|
||||
import numpy as np
|
||||
|
||||
from tianshou.utils import MovAvg
|
||||
from tianshou.exploration import GaussianNoise, OUNoise
|
||||
from tianshou.utils import SummaryWriter
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.utils.net.discrete import DQN
|
||||
from tianshou.exploration import GaussianNoise, OUNoise
|
||||
from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic
|
||||
|
||||
|
||||
@ -62,7 +63,25 @@ def test_net():
|
||||
assert list(net(data)[0].shape) == expect_output_shape
|
||||
|
||||
|
||||
def test_summary_writer():
|
||||
# get first instance by key of `default` or your own key
|
||||
writer1 = SummaryWriter.get_instance(
|
||||
key="first", log_dir="log/test_sw/first")
|
||||
assert writer1.log_dir == "log/test_sw/first"
|
||||
writer2 = SummaryWriter.get_instance()
|
||||
assert writer1 is writer2
|
||||
# create new instance by specify a new key
|
||||
writer3 = SummaryWriter.get_instance(
|
||||
key="second", log_dir="log/test_sw/second")
|
||||
assert writer3.log_dir == "log/test_sw/second"
|
||||
writer4 = SummaryWriter.get_instance(key="second")
|
||||
assert writer3 is writer4
|
||||
assert writer1 is not writer3
|
||||
assert writer1.log_dir != writer4.log_dir
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_noise()
|
||||
test_moving_average()
|
||||
test_net()
|
||||
test_summary_writer()
|
||||
|
@ -124,6 +124,7 @@ def test_dqn(args=get_args()):
|
||||
def test_pdqn(args=get_args()):
|
||||
args.prioritized_replay = 1
|
||||
args.gamma = .95
|
||||
args.seed = 1
|
||||
test_dqn(args)
|
||||
|
||||
|
||||
|
2
tianshou/env/worker/subproc.py
vendored
2
tianshou/env/worker/subproc.py
vendored
@ -32,7 +32,7 @@ class ShArray:
|
||||
|
||||
def __init__(self, dtype: np.generic, shape: Tuple[int]) -> None:
|
||||
self.arr = Array(
|
||||
_NP_TO_CT[dtype.type], # type: ignore
|
||||
_NP_TO_CT[dtype.type],
|
||||
int(np.prod(shape)),
|
||||
)
|
||||
self.dtype = dtype
|
||||
|
@ -178,10 +178,11 @@ class PPOPolicy(PGPolicy):
|
||||
losses.append(loss.item())
|
||||
self.optim.zero_grad()
|
||||
loss.backward()
|
||||
nn.utils.clip_grad_norm_(
|
||||
list(self.actor.parameters())
|
||||
+ list(self.critic.parameters()),
|
||||
self._max_grad_norm)
|
||||
if self._max_grad_norm:
|
||||
nn.utils.clip_grad_norm_(
|
||||
list(self.actor.parameters())
|
||||
+ list(self.critic.parameters()),
|
||||
self._max_grad_norm)
|
||||
self.optim.step()
|
||||
return {
|
||||
"loss": losses,
|
||||
|
@ -1,7 +1,9 @@
|
||||
from tianshou.utils.config import tqdm_config
|
||||
from tianshou.utils.moving_average import MovAvg
|
||||
from tianshou.utils.log_tools import SummaryWriter
|
||||
|
||||
__all__ = [
|
||||
"MovAvg",
|
||||
"tqdm_config",
|
||||
"SummaryWriter",
|
||||
]
|
||||
|
47
tianshou/utils/log_tools.py
Normal file
47
tianshou/utils/log_tools.py
Normal file
@ -0,0 +1,47 @@
|
||||
import threading
|
||||
from torch.utils import tensorboard
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
class SummaryWriter(tensorboard.SummaryWriter):
|
||||
"""A more convenient Summary Writer(`tensorboard.SummaryWriter`).
|
||||
|
||||
You can get the same instance of summary writer everywhere after you
|
||||
created one.
|
||||
::
|
||||
|
||||
>>> writer1 = SummaryWriter.get_instance(
|
||||
key="first", log_dir="log/test_sw/first")
|
||||
>>> writer2 = SummaryWriter.get_instance()
|
||||
>>> writer1 is writer2
|
||||
True
|
||||
>>> writer4 = SummaryWriter.get_instance(
|
||||
key="second", log_dir="log/test_sw/second")
|
||||
>>> writer5 = SummaryWriter.get_instance(key="second")
|
||||
>>> writer1 is not writer4
|
||||
True
|
||||
>>> writer4 is writer5
|
||||
True
|
||||
"""
|
||||
|
||||
_mutex_lock = threading.Lock()
|
||||
_default_key: str
|
||||
_instance: Optional[Dict[str, "SummaryWriter"]] = None
|
||||
|
||||
@classmethod
|
||||
def get_instance(
|
||||
cls,
|
||||
key: Optional[str] = None,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> "SummaryWriter":
|
||||
"""Get instance of torch.utils.tensorboard.SummaryWriter by key."""
|
||||
with SummaryWriter._mutex_lock:
|
||||
if key is None:
|
||||
key = SummaryWriter._default_key
|
||||
if SummaryWriter._instance is None:
|
||||
SummaryWriter._instance = {}
|
||||
SummaryWriter._default_key = key
|
||||
if key not in SummaryWriter._instance.keys():
|
||||
SummaryWriter._instance[key] = SummaryWriter(*args, **kwargs)
|
||||
return SummaryWriter._instance[key]
|
Loading…
x
Reference in New Issue
Block a user