From 12cccd847568372ae7d544c8e9c733df392f813d Mon Sep 17 00:00:00 2001 From: NM512 Date: Sat, 29 Apr 2023 07:34:27 +0900 Subject: [PATCH] addition of "is_first" and "is_terminal" for envs --- envs/atari.py | 9 +++++++-- envs/dmc.py | 6 ++++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/envs/atari.py b/envs/atari.py index e99fa20..025ca2d 100644 --- a/envs/atari.py +++ b/envs/atari.py @@ -98,7 +98,7 @@ class Atari: if not self._repeat: self._buffer[1][:] = self._buffer[0][:] self._screen(self._buffer[0]) - self._done = over or (self._length and self._step >= self._length) or dead + self._done = over or (self._length and self._step >= self._length) return self._obs( total, is_last=self._done or (dead and self._lives == "reset"), @@ -137,7 +137,12 @@ class Atari: weights = [0.299, 0.587, 1 - (0.299 + 0.587)] image = np.tensordot(image, weights, (-1, 0)).astype(image.dtype) image = image[:, :, None] - return {"image": image, "is_terminal": is_terminal}, reward, is_last, {} + return ( + {"image": image, "is_terminal": is_terminal, "is_first": is_first}, + reward, + is_last, + {}, + ) def _screen(self, array): self._ale.getScreenRGB2(array) diff --git a/envs/dmc.py b/envs/dmc.py index efffdf3..a39b29e 100644 --- a/envs/dmc.py +++ b/envs/dmc.py @@ -44,7 +44,8 @@ class DeepMindControl: obs = dict(time_step.observation) obs["image"] = self.render() # There is no terminal state in DMC - obs["is_terminal"] = False + obs["is_terminal"] = False if time_step.first() else time_step.discount == 0 + obs["is_first"] = time_step.first() done = time_step.last() info = {"discount": np.array(time_step.discount, np.float32)} return obs, reward, done, info @@ -53,7 +54,8 @@ class DeepMindControl: time_step = self._env.reset() obs = dict(time_step.observation) obs["image"] = self.render() - obs["is_terminal"] = False + obs["is_terminal"] = False if time_step.first() else time_step.discount == 0 + obs["is_first"] = time_step.first() return obs def render(self, *args, **kwargs):