bug fix for gym==0.19.0
This commit is contained in:
parent
d3156ecb06
commit
b8ef214efa
@ -68,7 +68,6 @@ class Atari:
|
|||||||
@property
|
@property
|
||||||
def observation_space(self):
|
def observation_space(self):
|
||||||
img_shape = self._size + ((1,) if self._gray else (3,))
|
img_shape = self._size + ((1,) if self._gray else (3,))
|
||||||
print(self._env.observation_space)
|
|
||||||
return gym.spaces.Dict(
|
return gym.spaces.Dict(
|
||||||
{
|
{
|
||||||
"image": gym.spaces.Box(0, 255, img_shape, np.uint8),
|
"image": gym.spaces.Box(0, 255, img_shape, np.uint8),
|
||||||
|
@ -180,7 +180,7 @@ class RewardObs:
|
|||||||
def observation_space(self):
|
def observation_space(self):
|
||||||
spaces = self._env.observation_space.spaces
|
spaces = self._env.observation_space.spaces
|
||||||
assert "reward" not in spaces
|
assert "reward" not in spaces
|
||||||
spaces["reward"] = gym.spaces.Box(-np.inf, np.inf, dtype=np.float32)
|
spaces["reward"] = gym.spaces.Box(-np.inf, np.inf, shape=(1,), dtype=np.float32)
|
||||||
return gym.spaces.Dict(spaces)
|
return gym.spaces.Dict(spaces)
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
|
@ -343,6 +343,8 @@ class MultiEncoder(nn.Module):
|
|||||||
symlog_inputs,
|
symlog_inputs,
|
||||||
):
|
):
|
||||||
super(MultiEncoder, self).__init__()
|
super(MultiEncoder, self).__init__()
|
||||||
|
excluded = ("is_first", "is_last", "is_terminal", "reward")
|
||||||
|
shapes = {k: v for k, v in shapes.items() if k not in excluded}
|
||||||
self.cnn_shapes = {
|
self.cnn_shapes = {
|
||||||
k: v for k, v in shapes.items() if len(v) == 3 and re.match(cnn_keys, k)
|
k: v for k, v in shapes.items() if len(v) == 3 and re.match(cnn_keys, k)
|
||||||
}
|
}
|
||||||
@ -402,6 +404,8 @@ class MultiDecoder(nn.Module):
|
|||||||
vector_dist,
|
vector_dist,
|
||||||
):
|
):
|
||||||
super(MultiDecoder, self).__init__()
|
super(MultiDecoder, self).__init__()
|
||||||
|
excluded = ("is_first", "is_last", "is_terminal", "reward")
|
||||||
|
shapes = {k: v for k, v in shapes.items() if k not in excluded}
|
||||||
self.cnn_shapes = {
|
self.cnn_shapes = {
|
||||||
k: v for k, v in shapes.items() if len(v) == 3 and re.match(cnn_keys, k)
|
k: v for k, v in shapes.items() if len(v) == 3 and re.match(cnn_keys, k)
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user