From b8ef214efad695c2d17d1cb0e32f5723db0e1613 Mon Sep 17 00:00:00 2001 From: NM512 Date: Thu, 18 May 2023 21:30:08 +0900 Subject: [PATCH] bug fix for gym==0.19.0 --- envs/atari.py | 1 - envs/wrappers.py | 2 +- networks.py | 4 ++++ 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/envs/atari.py b/envs/atari.py index 4d32c49..31f7a74 100644 --- a/envs/atari.py +++ b/envs/atari.py @@ -68,7 +68,6 @@ class Atari: @property def observation_space(self): img_shape = self._size + ((1,) if self._gray else (3,)) - print(self._env.observation_space) return gym.spaces.Dict( { "image": gym.spaces.Box(0, 255, img_shape, np.uint8), diff --git a/envs/wrappers.py b/envs/wrappers.py index a70611d..1a4a58b 100644 --- a/envs/wrappers.py +++ b/envs/wrappers.py @@ -180,7 +180,7 @@ class RewardObs: def observation_space(self): spaces = self._env.observation_space.spaces assert "reward" not in spaces - spaces["reward"] = gym.spaces.Box(-np.inf, np.inf, dtype=np.float32) + spaces["reward"] = gym.spaces.Box(-np.inf, np.inf, shape=(1,), dtype=np.float32) return gym.spaces.Dict(spaces) def step(self, action): diff --git a/networks.py b/networks.py index 8829cde..d122ec2 100644 --- a/networks.py +++ b/networks.py @@ -343,6 +343,8 @@ class MultiEncoder(nn.Module): symlog_inputs, ): super(MultiEncoder, self).__init__() + excluded = ("is_first", "is_last", "is_terminal", "reward") + shapes = {k: v for k, v in shapes.items() if k not in excluded} self.cnn_shapes = { k: v for k, v in shapes.items() if len(v) == 3 and re.match(cnn_keys, k) } @@ -402,6 +404,8 @@ class MultiDecoder(nn.Module): vector_dist, ): super(MultiDecoder, self).__init__() + excluded = ("is_first", "is_last", "is_terminal", "reward") + shapes = {k: v for k, v in shapes.items() if k not in excluded} self.cnn_shapes = { k: v for k, v in shapes.items() if len(v) == 3 and re.match(cnn_keys, k) }