diff --git a/envs/atari.py b/envs/atari.py index 025ca2d..4d32c49 100644 --- a/envs/atari.py +++ b/envs/atari.py @@ -1,3 +1,4 @@ +import gym import numpy as np @@ -64,6 +65,16 @@ class Atari: self._done = True self._step = 0 + @property + def observation_space(self): + img_shape = self._size + ((1,) if self._gray else (3,)) + print(self._env.observation_space) + return gym.spaces.Dict( + { + "image": gym.spaces.Box(0, 255, img_shape, np.uint8), + } + ) + @property def action_space(self): space = self._env.action_space