2020-09-12 15:39:01 +08:00
|
|
|
from numbers import Number
|
2020-03-12 22:20:33 +08:00
|
|
|
|
2021-09-03 05:05:04 +08:00
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
|
|
|
2020-03-12 22:20:33 +08:00
|
|
|
|
2023-08-25 23:40:56 +02:00
|
|
|
class MovAvg:
|
2020-09-11 07:55:37 +08:00
|
|
|
"""Class for moving average.
|
|
|
|
|
|
|
|
It will automatically exclude the infinity and NaN. Usage:
|
2020-04-03 21:28:12 +08:00
|
|
|
::
|
|
|
|
|
|
|
|
>>> stat = MovAvg(size=66)
|
|
|
|
>>> stat.add(torch.tensor(5))
|
|
|
|
5.0
|
|
|
|
>>> stat.add(float('inf')) # which will not add to stat
|
|
|
|
5.0
|
|
|
|
>>> stat.add([6, 7, 8])
|
|
|
|
6.5
|
|
|
|
>>> stat.get()
|
|
|
|
6.5
|
|
|
|
>>> print(f'{stat.mean():.2f}±{stat.std():.2f}')
|
|
|
|
6.50±1.12
|
|
|
|
"""
|
2020-05-12 11:31:47 +08:00
|
|
|
|
2020-05-16 20:08:32 +08:00
|
|
|
def __init__(self, size: int = 100) -> None:
|
2020-03-12 22:20:33 +08:00
|
|
|
super().__init__()
|
|
|
|
self.size = size
|
2023-08-25 23:40:56 +02:00
|
|
|
self.cache: list[np.number] = []
|
2020-04-23 22:06:18 +08:00
|
|
|
self.banned = [np.inf, np.nan, -np.inf]
|
2020-03-12 22:20:33 +08:00
|
|
|
|
2023-09-05 23:34:23 +02:00
|
|
|
def add(self, data_array: Number | np.number | list | np.ndarray | torch.Tensor) -> float:
|
2020-09-11 07:55:37 +08:00
|
|
|
"""Add a scalar into :class:`MovAvg`.
|
|
|
|
|
|
|
|
You can add ``torch.Tensor`` with only one element, a python scalar, or
|
|
|
|
a list of python scalar.
|
2020-04-03 21:28:12 +08:00
|
|
|
"""
|
2022-01-30 00:53:56 +08:00
|
|
|
if isinstance(data_array, torch.Tensor):
|
|
|
|
data_array = data_array.flatten().cpu().numpy()
|
|
|
|
if np.isscalar(data_array):
|
|
|
|
data_array = [data_array]
|
|
|
|
for number in data_array: # type: ignore
|
|
|
|
if number not in self.banned:
|
|
|
|
self.cache.append(number)
|
2020-03-12 22:20:33 +08:00
|
|
|
if self.size > 0 and len(self.cache) > self.size:
|
2023-08-25 23:40:56 +02:00
|
|
|
self.cache = self.cache[-self.size :]
|
2020-03-12 22:20:33 +08:00
|
|
|
return self.get()
|
|
|
|
|
2021-03-30 16:06:03 +08:00
|
|
|
def get(self) -> float:
|
2020-04-04 21:02:06 +08:00
|
|
|
"""Get the average."""
|
2020-03-12 22:20:33 +08:00
|
|
|
if len(self.cache) == 0:
|
2021-03-30 16:06:03 +08:00
|
|
|
return 0.0
|
2022-09-26 18:31:23 +02:00
|
|
|
return float(np.mean(self.cache)) # type: ignore
|
2020-03-15 17:41:00 +08:00
|
|
|
|
2021-03-30 16:06:03 +08:00
|
|
|
def mean(self) -> float:
|
2020-04-04 21:02:06 +08:00
|
|
|
"""Get the average. Same as :meth:`get`."""
|
2020-03-15 17:41:00 +08:00
|
|
|
return self.get()
|
|
|
|
|
2021-03-30 16:06:03 +08:00
|
|
|
def std(self) -> float:
|
2020-04-04 21:02:06 +08:00
|
|
|
"""Get the standard deviation."""
|
2020-03-15 17:41:00 +08:00
|
|
|
if len(self.cache) == 0:
|
2021-03-30 16:06:03 +08:00
|
|
|
return 0.0
|
2022-09-26 18:31:23 +02:00
|
|
|
return float(np.std(self.cache)) # type: ignore
|
2021-03-11 20:50:20 +08:00
|
|
|
|
|
|
|
|
2023-08-25 23:40:56 +02:00
|
|
|
class RunningMeanStd:
|
2021-09-03 05:05:04 +08:00
|
|
|
"""Calculates the running mean and std of a data stream.
|
2021-03-11 20:50:20 +08:00
|
|
|
|
|
|
|
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
|
2022-07-14 22:52:56 -07:00
|
|
|
|
|
|
|
:param mean: the initial mean estimation for data array. Default to 0.
|
|
|
|
:param std: the initial standard error estimation for data array. Default to 1.
|
2023-10-08 17:57:03 +02:00
|
|
|
:param clip_max: the maximum absolute value for data array. Default to
|
2022-07-14 22:52:56 -07:00
|
|
|
10.0.
|
2023-10-08 17:57:03 +02:00
|
|
|
:param epsilon: To avoid division by zero.
|
2021-03-11 20:50:20 +08:00
|
|
|
"""
|
|
|
|
|
2021-03-30 16:06:03 +08:00
|
|
|
def __init__(
|
2021-09-03 05:05:04 +08:00
|
|
|
self,
|
2023-09-05 23:34:23 +02:00
|
|
|
mean: float | np.ndarray = 0.0,
|
|
|
|
std: float | np.ndarray = 1.0,
|
|
|
|
clip_max: float | None = 10.0,
|
2022-07-14 22:52:56 -07:00
|
|
|
epsilon: float = np.finfo(np.float32).eps.item(),
|
2021-03-30 16:06:03 +08:00
|
|
|
) -> None:
|
|
|
|
self.mean, self.var = mean, std
|
2022-07-14 22:52:56 -07:00
|
|
|
self.clip_max = clip_max
|
2021-03-11 20:50:20 +08:00
|
|
|
self.count = 0
|
2022-07-14 22:52:56 -07:00
|
|
|
self.eps = epsilon
|
|
|
|
|
2023-09-05 23:34:23 +02:00
|
|
|
def norm(self, data_array: float | np.ndarray) -> float | np.ndarray:
|
2022-07-14 22:52:56 -07:00
|
|
|
data_array = (data_array - self.mean) / np.sqrt(self.var + self.eps)
|
|
|
|
if self.clip_max:
|
|
|
|
data_array = np.clip(data_array, -self.clip_max, self.clip_max)
|
|
|
|
return data_array
|
2021-03-11 20:50:20 +08:00
|
|
|
|
2022-01-30 00:53:56 +08:00
|
|
|
def update(self, data_array: np.ndarray) -> None:
|
2021-03-11 20:50:20 +08:00
|
|
|
"""Add a batch of item into RMS with the same shape, modify mean/var/count."""
|
2022-01-30 00:53:56 +08:00
|
|
|
batch_mean, batch_var = np.mean(data_array, axis=0), np.var(data_array, axis=0)
|
|
|
|
batch_count = len(data_array)
|
2021-03-11 20:50:20 +08:00
|
|
|
|
|
|
|
delta = batch_mean - self.mean
|
|
|
|
total_count = self.count + batch_count
|
|
|
|
|
|
|
|
new_mean = self.mean + delta * batch_count / total_count
|
|
|
|
m_a = self.var * self.count
|
|
|
|
m_b = batch_var * batch_count
|
2021-09-03 05:05:04 +08:00
|
|
|
m_2 = m_a + m_b + delta**2 * self.count * batch_count / total_count
|
2021-03-11 20:50:20 +08:00
|
|
|
new_var = m_2 / total_count
|
|
|
|
|
2021-05-11 18:24:48 -07:00
|
|
|
self.mean, self.var = new_mean, new_var
|
2021-03-11 20:50:20 +08:00
|
|
|
self.count = total_count
|