Tianshou/tianshou/utils/statistics.py

112 lines
3.6 KiB
Python
Raw Normal View History

from numbers import Number
2020-03-12 22:20:33 +08:00
import numpy as np
import torch
2020-03-12 22:20:33 +08:00
class MovAvg:
"""Class for moving average.
It will automatically exclude the infinity and NaN. Usage:
2020-04-03 21:28:12 +08:00
::
>>> stat = MovAvg(size=66)
>>> stat.add(torch.tensor(5))
5.0
>>> stat.add(float('inf')) # which will not add to stat
5.0
>>> stat.add([6, 7, 8])
6.5
>>> stat.get()
6.5
>>> print(f'{stat.mean():.2f}±{stat.std():.2f}')
6.50±1.12
"""
2020-05-12 11:31:47 +08:00
2020-05-16 20:08:32 +08:00
def __init__(self, size: int = 100) -> None:
2020-03-12 22:20:33 +08:00
super().__init__()
self.size = size
self.cache: list[np.number] = []
2020-04-23 22:06:18 +08:00
self.banned = [np.inf, np.nan, -np.inf]
2020-03-12 22:20:33 +08:00
def add(self, data_array: Number | np.number | list | np.ndarray | torch.Tensor) -> float:
"""Add a scalar into :class:`MovAvg`.
You can add ``torch.Tensor`` with only one element, a python scalar, or
a list of python scalar.
2020-04-03 21:28:12 +08:00
"""
if isinstance(data_array, torch.Tensor):
data_array = data_array.flatten().cpu().numpy()
if np.isscalar(data_array):
data_array = [data_array]
for number in data_array: # type: ignore
if number not in self.banned:
self.cache.append(number)
2020-03-12 22:20:33 +08:00
if self.size > 0 and len(self.cache) > self.size:
self.cache = self.cache[-self.size :]
2020-03-12 22:20:33 +08:00
return self.get()
def get(self) -> float:
2020-04-04 21:02:06 +08:00
"""Get the average."""
2020-03-12 22:20:33 +08:00
if len(self.cache) == 0:
return 0.0
return float(np.mean(self.cache)) # type: ignore
2020-03-15 17:41:00 +08:00
def mean(self) -> float:
2020-04-04 21:02:06 +08:00
"""Get the average. Same as :meth:`get`."""
2020-03-15 17:41:00 +08:00
return self.get()
def std(self) -> float:
2020-04-04 21:02:06 +08:00
"""Get the standard deviation."""
2020-03-15 17:41:00 +08:00
if len(self.cache) == 0:
return 0.0
return float(np.std(self.cache)) # type: ignore
class RunningMeanStd:
"""Calculates the running mean and std of a data stream.
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
:param mean: the initial mean estimation for data array. Default to 0.
:param std: the initial standard error estimation for data array. Default to 1.
Remove kwargs in policy init (#950) Closes #947 This removes all kwargs from all policy constructors. While doing that, I also improved several names and added a whole lot of TODOs. ## Functional changes: 1. Added possibility to pass None as `critic2` and `critic2_optim`. In fact, the default behavior then should cover the absolute majority of cases 2. Added a function called `clone_optimizer` as a temporary measure to support passing `critic2_optim=None` ## Breaking changes: 1. `action_space` is no longer optional. In fact, it already was non-optional, as there was a ValueError in BasePolicy.init. So now several examples were fixed to reflect that 2. `reward_normalization` removed from DDPG and children. It was never allowed to pass it as `True` there, an error would have been raised in `compute_n_step_reward`. Now I removed it from the interface 3. renamed `critic1` and similar to `critic`, in order to have uniform interfaces. Note that the `critic` in DDPG was optional for the sole reason that child classes used `critic1`. I removed this optionality (DDPG can't do anything with `critic=None`) 4. Several renamings of fields (mostly private to public, so backwards compatible) ## Additional changes: 1. Removed type and default declaration from docstring. This kind of duplication is really not necessary 2. Policy constructors are now only called using named arguments, not a fragile mixture of positional and named as before 5. Minor beautifications in typing and code 6. Generally shortened docstrings and made them uniform across all policies (hopefully) ## Comment: With these changes, several problems in tianshou's inheritance hierarchy become more apparent. I tried highlighting them for future work. --------- Co-authored-by: Dominik Jain <d.jain@appliedai.de>
2023-10-08 17:57:03 +02:00
:param clip_max: the maximum absolute value for data array. Default to
10.0.
Remove kwargs in policy init (#950) Closes #947 This removes all kwargs from all policy constructors. While doing that, I also improved several names and added a whole lot of TODOs. ## Functional changes: 1. Added possibility to pass None as `critic2` and `critic2_optim`. In fact, the default behavior then should cover the absolute majority of cases 2. Added a function called `clone_optimizer` as a temporary measure to support passing `critic2_optim=None` ## Breaking changes: 1. `action_space` is no longer optional. In fact, it already was non-optional, as there was a ValueError in BasePolicy.init. So now several examples were fixed to reflect that 2. `reward_normalization` removed from DDPG and children. It was never allowed to pass it as `True` there, an error would have been raised in `compute_n_step_reward`. Now I removed it from the interface 3. renamed `critic1` and similar to `critic`, in order to have uniform interfaces. Note that the `critic` in DDPG was optional for the sole reason that child classes used `critic1`. I removed this optionality (DDPG can't do anything with `critic=None`) 4. Several renamings of fields (mostly private to public, so backwards compatible) ## Additional changes: 1. Removed type and default declaration from docstring. This kind of duplication is really not necessary 2. Policy constructors are now only called using named arguments, not a fragile mixture of positional and named as before 5. Minor beautifications in typing and code 6. Generally shortened docstrings and made them uniform across all policies (hopefully) ## Comment: With these changes, several problems in tianshou's inheritance hierarchy become more apparent. I tried highlighting them for future work. --------- Co-authored-by: Dominik Jain <d.jain@appliedai.de>
2023-10-08 17:57:03 +02:00
:param epsilon: To avoid division by zero.
"""
def __init__(
self,
mean: float | np.ndarray = 0.0,
std: float | np.ndarray = 1.0,
clip_max: float | None = 10.0,
epsilon: float = np.finfo(np.float32).eps.item(),
) -> None:
self.mean, self.var = mean, std
self.clip_max = clip_max
self.count = 0
self.eps = epsilon
def norm(self, data_array: float | np.ndarray) -> float | np.ndarray:
data_array = (data_array - self.mean) / np.sqrt(self.var + self.eps)
if self.clip_max:
data_array = np.clip(data_array, -self.clip_max, self.clip_max)
return data_array
def update(self, data_array: np.ndarray) -> None:
"""Add a batch of item into RMS with the same shape, modify mean/var/count."""
batch_mean, batch_var = np.mean(data_array, axis=0), np.var(data_array, axis=0)
batch_count = len(data_array)
delta = batch_mean - self.mean
total_count = self.count + batch_count
new_mean = self.mean + delta * batch_count / total_count
m_a = self.var * self.count
m_b = batch_var * batch_count
m_2 = m_a + m_b + delta**2 * self.count * batch_count / total_count
new_var = m_2 / total_count
self.mean, self.var = new_mean, new_var
self.count = total_count