Tianshou/tianshou/utils/statistics.py
n+e 09692c84fe
fix numpy>=1.20 typing check (#323)
Change the behavior of to_numpy and to_torch: from now on, dict is automatically converted to Batch and list is automatically converted to np.ndarray (if an error occurs, raise the exception instead of converting each element in the list).
2021-03-30 16:06:03 +08:00

96 lines
2.8 KiB
Python

import torch
import numpy as np
from numbers import Number
from typing import List, Union
class MovAvg(object):
    """Class for moving average.

    It will automatically exclude the infinity and NaN. Usage:
    ::

        >>> stat = MovAvg(size=66)
        >>> stat.add(torch.tensor(5))
        5.0
        >>> stat.add(float('inf'))  # which will not add to stat
        5.0
        >>> stat.add(float('nan'))  # NaN is skipped as well
        5.0
        >>> stat.add([6, 7, 8])
        6.5
        >>> stat.get()
        6.5
        >>> print(f'{stat.mean():.2f}±{stat.std():.2f}')
        6.50±1.12

    :param int size: maximum number of most-recent values kept in the window;
        a non-positive value disables truncation (the window grows unbounded).
    """

    def __init__(self, size: int = 100) -> None:
        super().__init__()
        self.size = size
        self.cache: List[np.number] = []
        # Kept for backward compatibility: values listed here are rejected in
        # addition to every non-finite value (see :meth:`add`).
        self.banned = [np.inf, np.nan, -np.inf]

    def add(
        self, x: Union[Number, np.number, list, np.ndarray, torch.Tensor]
    ) -> float:
        """Add a scalar (or several scalars) into :class:`MovAvg`.

        You can add a ``torch.Tensor``, a python/numpy scalar, a list of
        scalars or a ``np.ndarray`` (a 0-d array counts as one scalar).
        Non-finite values (``inf``, ``-inf`` and any ``NaN``) are skipped.

        :return: the moving average after the insertion.
        """
        if isinstance(x, torch.Tensor):
            # detach() is required because Tensor.numpy() refuses tensors that
            # are part of an autograd graph (e.g. a loss value).
            x = x.detach().flatten().cpu().numpy()
        if np.isscalar(x):
            x = [x]
        elif isinstance(x, np.ndarray) and x.ndim == 0:
            # A 0-d array is not iterable; view it as a single element.
            x = x.reshape(1)
        for i in x:  # type: ignore
            # A membership test alone cannot detect NaN because NaN != NaN
            # (only the very same ``np.nan`` object would match by identity),
            # so finiteness is checked explicitly.
            if np.isfinite(i) and i not in self.banned:
                self.cache.append(i)
        if self.size > 0 and len(self.cache) > self.size:
            # Keep only the ``size`` most recent values.
            self.cache = self.cache[-self.size:]
        return self.get()

    def get(self) -> float:
        """Get the average (``0.0`` when no value has been stored yet)."""
        if len(self.cache) == 0:
            return 0.0
        return float(np.mean(self.cache))

    def mean(self) -> float:
        """Get the average. Same as :meth:`get`."""
        return self.get()

    def std(self) -> float:
        """Get the standard deviation (``0.0`` when the window is empty)."""
        if len(self.cache) == 0:
            return 0.0
        return float(np.std(self.cache))
class RunningMeanStd(object):
    """Calculates the running mean and variance of a data stream.

    The statistics are merged batch by batch with the parallel algorithm from
    https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm

    :param mean: initial value of the running mean.
    :param std: initial value stored in ``var`` (see the note in ``__init__``).
    """

    def __init__(
        self, mean: Union[float, np.ndarray] = 0.0, std: Union[float, np.ndarray] = 1.0
    ) -> None:
        # NOTE(review): ``std`` is stored unchanged as the initial ``var``
        # (not squared). Both agree for the default of 1.0; kept as-is so
        # existing callers see the same values -- confirm whether intended.
        self.mean, self.var = mean, std
        # Number of samples merged so far. While it is 0, the initial
        # mean/var carry no weight in the first update.
        self.count = 0

    def update(self, x: np.ndarray) -> None:
        """Add a batch of item into RMS with the same shape, modify mean/var/count.

        The first axis of ``x`` is treated as the batch axis. An empty batch
        leaves the statistics untouched.
        """
        batch_count = len(x)
        if batch_count == 0:
            # Nothing to merge; computing the mean of an empty batch and
            # dividing by a zero total count would poison mean/var with NaN.
            return
        batch_mean, batch_var = np.mean(x, axis=0), np.var(x, axis=0)

        delta = batch_mean - self.mean
        total_count = self.count + batch_count

        new_mean = self.mean + delta * batch_count / total_count
        # Combine the sums of squared deviations of both partitions, plus the
        # correction term accounting for the distance between their means.
        m_a = self.var * self.count
        m_b = batch_var * batch_count
        m_2 = m_a + m_b + delta ** 2 * self.count * batch_count / total_count
        new_var = m_2 / total_count

        self.mean, self.var = new_mean, new_var  # type: ignore
        self.count = total_count