add singleton pattern version of summary_writter (#230)
Co-authored-by: Trinkle23897 <trinkle23897@gmail.com>
This commit is contained in:
parent
b364f1a26f
commit
c97aa4065e
@ -2,9 +2,10 @@ import torch
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tianshou.utils import MovAvg
|
from tianshou.utils import MovAvg
|
||||||
from tianshou.exploration import GaussianNoise, OUNoise
|
from tianshou.utils import SummaryWriter
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
from tianshou.utils.net.discrete import DQN
|
from tianshou.utils.net.discrete import DQN
|
||||||
|
from tianshou.exploration import GaussianNoise, OUNoise
|
||||||
from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic
|
from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic
|
||||||
|
|
||||||
|
|
||||||
@ -62,7 +63,25 @@ def test_net():
|
|||||||
assert list(net(data)[0].shape) == expect_output_shape
|
assert list(net(data)[0].shape) == expect_output_shape
|
||||||
|
|
||||||
|
|
||||||
|
def test_summary_writer():
|
||||||
|
# get first instance by key of `default` or your own key
|
||||||
|
writer1 = SummaryWriter.get_instance(
|
||||||
|
key="first", log_dir="log/test_sw/first")
|
||||||
|
assert writer1.log_dir == "log/test_sw/first"
|
||||||
|
writer2 = SummaryWriter.get_instance()
|
||||||
|
assert writer1 is writer2
|
||||||
|
# create new instance by specify a new key
|
||||||
|
writer3 = SummaryWriter.get_instance(
|
||||||
|
key="second", log_dir="log/test_sw/second")
|
||||||
|
assert writer3.log_dir == "log/test_sw/second"
|
||||||
|
writer4 = SummaryWriter.get_instance(key="second")
|
||||||
|
assert writer3 is writer4
|
||||||
|
assert writer1 is not writer3
|
||||||
|
assert writer1.log_dir != writer4.log_dir
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_noise()
|
test_noise()
|
||||||
test_moving_average()
|
test_moving_average()
|
||||||
test_net()
|
test_net()
|
||||||
|
test_summary_writer()
|
||||||
|
@ -124,6 +124,7 @@ def test_dqn(args=get_args()):
|
|||||||
def test_pdqn(args=get_args()):
|
def test_pdqn(args=get_args()):
|
||||||
args.prioritized_replay = 1
|
args.prioritized_replay = 1
|
||||||
args.gamma = .95
|
args.gamma = .95
|
||||||
|
args.seed = 1
|
||||||
test_dqn(args)
|
test_dqn(args)
|
||||||
|
|
||||||
|
|
||||||
|
2
tianshou/env/worker/subproc.py
vendored
2
tianshou/env/worker/subproc.py
vendored
@ -32,7 +32,7 @@ class ShArray:
|
|||||||
|
|
||||||
def __init__(self, dtype: np.generic, shape: Tuple[int]) -> None:
|
def __init__(self, dtype: np.generic, shape: Tuple[int]) -> None:
|
||||||
self.arr = Array(
|
self.arr = Array(
|
||||||
_NP_TO_CT[dtype.type], # type: ignore
|
_NP_TO_CT[dtype.type],
|
||||||
int(np.prod(shape)),
|
int(np.prod(shape)),
|
||||||
)
|
)
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
|
@ -178,6 +178,7 @@ class PPOPolicy(PGPolicy):
|
|||||||
losses.append(loss.item())
|
losses.append(loss.item())
|
||||||
self.optim.zero_grad()
|
self.optim.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
if self._max_grad_norm:
|
||||||
nn.utils.clip_grad_norm_(
|
nn.utils.clip_grad_norm_(
|
||||||
list(self.actor.parameters())
|
list(self.actor.parameters())
|
||||||
+ list(self.critic.parameters()),
|
+ list(self.critic.parameters()),
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
from tianshou.utils.config import tqdm_config
|
from tianshou.utils.config import tqdm_config
|
||||||
from tianshou.utils.moving_average import MovAvg
|
from tianshou.utils.moving_average import MovAvg
|
||||||
|
from tianshou.utils.log_tools import SummaryWriter
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"MovAvg",
|
"MovAvg",
|
||||||
"tqdm_config",
|
"tqdm_config",
|
||||||
|
"SummaryWriter",
|
||||||
]
|
]
|
||||||
|
47
tianshou/utils/log_tools.py
Normal file
47
tianshou/utils/log_tools.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
import threading
|
||||||
|
from torch.utils import tensorboard
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class SummaryWriter(tensorboard.SummaryWriter):
|
||||||
|
"""A more convenient Summary Writer(`tensorboard.SummaryWriter`).
|
||||||
|
|
||||||
|
You can get the same instance of summary writer everywhere after you
|
||||||
|
created one.
|
||||||
|
::
|
||||||
|
|
||||||
|
>>> writer1 = SummaryWriter.get_instance(
|
||||||
|
key="first", log_dir="log/test_sw/first")
|
||||||
|
>>> writer2 = SummaryWriter.get_instance()
|
||||||
|
>>> writer1 is writer2
|
||||||
|
True
|
||||||
|
>>> writer4 = SummaryWriter.get_instance(
|
||||||
|
key="second", log_dir="log/test_sw/second")
|
||||||
|
>>> writer5 = SummaryWriter.get_instance(key="second")
|
||||||
|
>>> writer1 is not writer4
|
||||||
|
True
|
||||||
|
>>> writer4 is writer5
|
||||||
|
True
|
||||||
|
"""
|
||||||
|
|
||||||
|
_mutex_lock = threading.Lock()
|
||||||
|
_default_key: str
|
||||||
|
_instance: Optional[Dict[str, "SummaryWriter"]] = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_instance(
|
||||||
|
cls,
|
||||||
|
key: Optional[str] = None,
|
||||||
|
*args: Any,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> "SummaryWriter":
|
||||||
|
"""Get instance of torch.utils.tensorboard.SummaryWriter by key."""
|
||||||
|
with SummaryWriter._mutex_lock:
|
||||||
|
if key is None:
|
||||||
|
key = SummaryWriter._default_key
|
||||||
|
if SummaryWriter._instance is None:
|
||||||
|
SummaryWriter._instance = {}
|
||||||
|
SummaryWriter._default_key = key
|
||||||
|
if key not in SummaryWriter._instance.keys():
|
||||||
|
SummaryWriter._instance[key] = SummaryWriter(*args, **kwargs)
|
||||||
|
return SummaryWriter._instance[key]
|
Loading…
x
Reference in New Issue
Block a user