2020-03-14 21:48:31 +08:00
|
|
|
import torch
|
2020-03-12 22:20:33 +08:00
|
|
|
import numpy as np
|
2020-05-16 20:08:32 +08:00
|
|
|
from typing import Union
|
2020-03-12 22:20:33 +08:00
|
|
|
|
2020-06-02 22:29:50 +08:00
|
|
|
from tianshou.data import to_numpy
|
|
|
|
|
2020-03-12 22:20:33 +08:00
|
|
|
|
|
|
|
class MovAvg(object):
    """Class for moving average. It will automatically exclude the infinity and
    NaN. Usage:
    ::

        >>> stat = MovAvg(size=66)
        >>> stat.add(torch.tensor(5))
        5.0
        >>> stat.add(float('inf'))  # which will not add to stat
        5.0
        >>> stat.add([6, 7, 8])
        6.5
        >>> stat.get()
        6.5
        >>> print(f'{stat.mean():.2f}±{stat.std():.2f}')
        6.50±1.12
    """

    def __init__(self, size: int = 100) -> None:
        """Create an empty moving-average window.

        :param int size: maximum number of most recent values kept in the
            window; a non-positive value disables trimming, so every accepted
            value is kept.
        """
        super().__init__()
        self.size = size
        self.cache = []
        # Extra values to reject. Non-finite numbers are filtered with
        # ``np.isfinite`` in :meth:`add` because ``nan != nan`` makes a
        # membership test against this list unreliable for NaN.
        self.banned = [np.inf, np.nan, -np.inf]

    def add(self, x: Union[float, list, np.ndarray, torch.Tensor]) -> float:
        """Add one or several scalars into :class:`MovAvg`. You can add a
        ``torch.Tensor``, a ``np.ndarray`` of any shape, a python scalar, or
        a list of python scalars.

        Infinite and NaN entries (and anything listed in ``self.banned``) are
        silently skipped. Accepted values are stored as python floats.

        :return: the average of the window after the update.
        """
        if isinstance(x, torch.Tensor):
            x = to_numpy(x.flatten())
        # Flatten everything (scalars, 0-d and n-d arrays, lists) to a 1-d
        # float array so that all input kinds share a single code path.
        values = np.asarray(x, dtype=float).ravel()
        finite = values[np.isfinite(values)].tolist()
        self.cache.extend(v for v in finite if v not in self.banned)
        # Keep only the ``size`` most recent values.
        if self.size > 0 and len(self.cache) > self.size:
            self.cache = self.cache[-self.size:]
        return self.get()

    def get(self) -> float:
        """Get the average (0.0 if nothing has been added yet)."""
        if not self.cache:
            return 0.0
        return float(np.mean(self.cache))

    def mean(self) -> float:
        """Get the average. Same as :meth:`get`."""
        return self.get()

    def std(self) -> float:
        """Get the standard deviation (0.0 if nothing has been added yet)."""
        if not self.cache:
            return 0.0
        return float(np.std(self.cache))
|