From 8812eaa5020ff3c7c84322cfccbf1c99d7210fc8 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Thu, 23 Apr 2020 22:06:18 +0800 Subject: [PATCH] fix #36 --- tianshou/env/atari.py | 2 +- tianshou/utils/moving_average.py | 11 ++++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/tianshou/env/atari.py b/tianshou/env/atari.py index 3f67da0..9904f08 100644 --- a/tianshou/env/atari.py +++ b/tianshou/env/atari.py @@ -60,7 +60,7 @@ class preprocessing(object): return np.stack([ self._pool_and_resize() for _ in range(self.frame_skip)], axis=-1) - def render(self, mode): + def render(self, mode='human'): return self.env.render(mode) def step(self, action): diff --git a/tianshou/utils/moving_average.py b/tianshou/utils/moving_average.py index 035c5e2..3aefb23 100644 --- a/tianshou/utils/moving_average.py +++ b/tianshou/utils/moving_average.py @@ -3,7 +3,8 @@ import numpy as np class MovAvg(object): - """Class for moving average. Usage: + """Class for moving average. It will automatically exclude the infinity and + NaN. Usage: :: >>> stat = MovAvg(size=66) @@ -22,19 +23,19 @@ class MovAvg(object): super().__init__() self.size = size self.cache = [] + self.banned = [np.inf, np.nan, -np.inf] def add(self, x): """Add a scalar into :class:`MovAvg`. You can add ``torch.Tensor`` with - only one element, a python scalar, or a list of python scalar. It will - automatically exclude the infinity and NaN. + only one element, a python scalar, or a list of python scalar. """ if isinstance(x, torch.Tensor): x = x.item() if isinstance(x, list): for _ in x: - if _ not in [np.inf, np.nan, -np.inf]: + if _ not in self.banned: self.cache.append(_) - elif x != np.inf: + elif x not in self.banned: self.cache.append(x) if self.size > 0 and len(self.cache) > self.size: self.cache = self.cache[-self.size:]