fix #36
This commit is contained in:
parent
205698dd66
commit
8812eaa502
2
tianshou/env/atari.py
vendored
2
tianshou/env/atari.py
vendored
@ -60,7 +60,7 @@ class preprocessing(object):
|
|||||||
return np.stack([
|
return np.stack([
|
||||||
self._pool_and_resize() for _ in range(self.frame_skip)], axis=-1)
|
self._pool_and_resize() for _ in range(self.frame_skip)], axis=-1)
|
||||||
|
|
||||||
def render(self, mode):
|
def render(self, mode='human'):
|
||||||
return self.env.render(mode)
|
return self.env.render(mode)
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
|
@ -3,7 +3,8 @@ import numpy as np
|
|||||||
|
|
||||||
|
|
||||||
class MovAvg(object):
|
class MovAvg(object):
|
||||||
"""Class for moving average. Usage:
|
"""Class for moving average. It will automatically exclude the infinity and
|
||||||
|
NaN. Usage:
|
||||||
::
|
::
|
||||||
|
|
||||||
>>> stat = MovAvg(size=66)
|
>>> stat = MovAvg(size=66)
|
||||||
@ -22,19 +23,19 @@ class MovAvg(object):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.size = size
|
self.size = size
|
||||||
self.cache = []
|
self.cache = []
|
||||||
|
self.banned = [np.inf, np.nan, -np.inf]
|
||||||
|
|
||||||
def add(self, x):
|
def add(self, x):
|
||||||
"""Add a scalar into :class:`MovAvg`. You can add ``torch.Tensor`` with
|
"""Add a scalar into :class:`MovAvg`. You can add ``torch.Tensor`` with
|
||||||
only one element, a python scalar, or a list of python scalar. It will
|
only one element, a python scalar, or a list of python scalar.
|
||||||
automatically exclude the infinity and NaN.
|
|
||||||
"""
|
"""
|
||||||
if isinstance(x, torch.Tensor):
|
if isinstance(x, torch.Tensor):
|
||||||
x = x.item()
|
x = x.item()
|
||||||
if isinstance(x, list):
|
if isinstance(x, list):
|
||||||
for _ in x:
|
for _ in x:
|
||||||
if _ not in [np.inf, np.nan, -np.inf]:
|
if _ not in self.banned:
|
||||||
self.cache.append(_)
|
self.cache.append(_)
|
||||||
elif x != np.inf:
|
elif x not in self.banned:
|
||||||
self.cache.append(x)
|
self.cache.append(x)
|
||||||
if self.size > 0 and len(self.cache) > self.size:
|
if self.size > 0 and len(self.cache) > self.size:
|
||||||
self.cache = self.cache[-self.size:]
|
self.cache = self.cache[-self.size:]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user