addition of "is_first" and "is_terminal" for envs

This commit is contained in:
NM512 2023-04-29 07:34:27 +09:00
parent 3d0e2c8b5a
commit 12cccd8475
2 changed files with 11 additions and 4 deletions

View File

@ -98,7 +98,7 @@ class Atari:
if not self._repeat:
self._buffer[1][:] = self._buffer[0][:]
self._screen(self._buffer[0])
self._done = over or (self._length and self._step >= self._length) or dead
self._done = over or (self._length and self._step >= self._length)
return self._obs(
total,
is_last=self._done or (dead and self._lives == "reset"),
@ -137,7 +137,12 @@ class Atari:
weights = [0.299, 0.587, 1 - (0.299 + 0.587)]
image = np.tensordot(image, weights, (-1, 0)).astype(image.dtype)
image = image[:, :, None]
return {"image": image, "is_terminal": is_terminal}, reward, is_last, {}
return (
{"image": image, "is_terminal": is_terminal, "is_first": is_first},
reward,
is_last,
{},
)
def _screen(self, array):
self._ale.getScreenRGB2(array)

View File

@ -44,7 +44,8 @@ class DeepMindControl:
obs = dict(time_step.observation)
obs["image"] = self.render()
# There is no terminal state in DMC
obs["is_terminal"] = False
obs["is_terminal"] = False if time_step.first() else time_step.discount == 0
obs["is_first"] = time_step.first()
done = time_step.last()
info = {"discount": np.array(time_step.discount, np.float32)}
return obs, reward, done, info
@ -53,7 +54,8 @@ class DeepMindControl:
time_step = self._env.reset()
obs = dict(time_step.observation)
obs["image"] = self.render()
obs["is_terminal"] = False
obs["is_terminal"] = False if time_step.first() else time_step.discount == 0
obs["is_first"] = time_step.first()
return obs
def render(self, *args, **kwargs):