fix #36
This commit is contained in:
parent
205698dd66
commit
8812eaa502
2
tianshou/env/atari.py
vendored
2
tianshou/env/atari.py
vendored
@ -60,7 +60,7 @@ class preprocessing(object):
|
||||
return np.stack([
|
||||
self._pool_and_resize() for _ in range(self.frame_skip)], axis=-1)
|
||||
|
||||
def render(self, mode):
|
||||
def render(self, mode='human'):
|
||||
return self.env.render(mode)
|
||||
|
||||
def step(self, action):
|
||||
|
@ -3,7 +3,8 @@ import numpy as np
|
||||
|
||||
|
||||
class MovAvg(object):
|
||||
"""Class for moving average. Usage:
|
||||
"""Class for moving average. It will automatically exclude the infinity and
|
||||
NaN. Usage:
|
||||
::
|
||||
|
||||
>>> stat = MovAvg(size=66)
|
||||
@ -22,19 +23,19 @@ class MovAvg(object):
|
||||
super().__init__()
|
||||
self.size = size
|
||||
self.cache = []
|
||||
self.banned = [np.inf, np.nan, -np.inf]
|
||||
|
||||
def add(self, x):
|
||||
"""Add a scalar into :class:`MovAvg`. You can add ``torch.Tensor`` with
|
||||
only one element, a python scalar, or a list of python scalar. It will
|
||||
automatically exclude the infinity and NaN.
|
||||
only one element, a python scalar, or a list of python scalar.
|
||||
"""
|
||||
if isinstance(x, torch.Tensor):
|
||||
x = x.item()
|
||||
if isinstance(x, list):
|
||||
for _ in x:
|
||||
if _ not in [np.inf, np.nan, -np.inf]:
|
||||
if _ not in self.banned:
|
||||
self.cache.append(_)
|
||||
elif x != np.inf:
|
||||
elif x not in self.banned:
|
||||
self.cache.append(x)
|
||||
if self.size > 0 and len(self.cache) > self.size:
|
||||
self.cache = self.cache[-self.size:]
|
||||
|
Loading…
x
Reference in New Issue
Block a user