Tianshou/tianshou/utils/moving_average.py

60 lines
1.5 KiB
Python
Raw Normal View History

2020-03-14 21:48:31 +08:00
import torch
2020-03-12 22:20:33 +08:00
import numpy as np
class MovAvg(object):
    """Class for moving average. Usage:
    ::

        >>> stat = MovAvg(size=66)
        >>> stat.add(torch.tensor(5))
        5.0
        >>> stat.add(float('inf'))  # which will not add to stat
        5.0
        >>> stat.add([6, 7, 8])
        6.5
        >>> stat.get()
        6.5
        >>> print(f'{stat.mean():.2f}±{stat.std():.2f}')
        6.50±1.12

    Non-finite values (``inf``, ``-inf`` and ``nan``) are never added to the
    statistics.
    """

    def __init__(self, size=100):
        """
        :param int size: maximum number of most recent values kept in the
            window; a value ``<= 0`` keeps every value ever added.
        """
        super().__init__()
        self.size = size
        self.cache = []

    def add(self, x):
        """Add value(s) into :class:`MovAvg` and return the updated average.

        ``x`` may be a python scalar, a numpy scalar, a list/tuple of
        scalars, a ``numpy.ndarray`` or a ``torch.Tensor`` of any shape (its
        elements are flattened). Non-finite values (``inf``, ``-inf`` and
        ``nan``) are skipped, because a single one would otherwise corrupt
        the mean and standard deviation until it left the window.
        """
        # Turn array-likes into flat python lists so that each element is
        # filtered individually; tolist() also works on tensors that require
        # grad or live on a GPU.
        if isinstance(x, torch.Tensor):
            x = x.flatten().tolist()
        elif isinstance(x, np.ndarray):
            x = x.ravel().tolist()
        values = x if isinstance(x, (list, tuple)) else [x]
        self.cache.extend(v for v in values if np.isfinite(v))
        if self.size > 0 and len(self.cache) > self.size:
            # Drop the oldest entries in place, keeping only the last `size`.
            del self.cache[:-self.size]
        return self.get()

    def get(self):
        """Get the average of the cached values, or 0 if nothing is cached."""
        if not self.cache:
            return 0
        return np.mean(self.cache)

    def mean(self):
        """Get the average. Same as :meth:`get`."""
        return self.get()

    def std(self):
        """Get the (population) standard deviation of the cached values, or
        0 if nothing is cached."""
        if not self.cache:
            return 0
        return np.std(self.cache)