addition of "is_first" and "is_terminal" for envs
This commit is contained in:
parent
3d0e2c8b5a
commit
12cccd8475
@ -98,7 +98,7 @@ class Atari:
|
||||
if not self._repeat:
|
||||
self._buffer[1][:] = self._buffer[0][:]
|
||||
self._screen(self._buffer[0])
|
||||
self._done = over or (self._length and self._step >= self._length) or dead
|
||||
self._done = over or (self._length and self._step >= self._length)
|
||||
return self._obs(
|
||||
total,
|
||||
is_last=self._done or (dead and self._lives == "reset"),
|
||||
@ -137,7 +137,12 @@ class Atari:
|
||||
weights = [0.299, 0.587, 1 - (0.299 + 0.587)]
|
||||
image = np.tensordot(image, weights, (-1, 0)).astype(image.dtype)
|
||||
image = image[:, :, None]
|
||||
return {"image": image, "is_terminal": is_terminal}, reward, is_last, {}
|
||||
return (
|
||||
{"image": image, "is_terminal": is_terminal, "is_first": is_first},
|
||||
reward,
|
||||
is_last,
|
||||
{},
|
||||
)
|
||||
|
||||
def _screen(self, array):
|
||||
self._ale.getScreenRGB2(array)
|
||||
|
@ -44,7 +44,8 @@ class DeepMindControl:
|
||||
obs = dict(time_step.observation)
|
||||
obs["image"] = self.render()
|
||||
# A step is terminal only when it is not the first step and its discount is 0
|
||||
obs["is_terminal"] = False
|
||||
obs["is_terminal"] = False if time_step.first() else time_step.discount == 0
|
||||
obs["is_first"] = time_step.first()
|
||||
done = time_step.last()
|
||||
info = {"discount": np.array(time_step.discount, np.float32)}
|
||||
return obs, reward, done, info
|
||||
@ -53,7 +54,8 @@ class DeepMindControl:
|
||||
time_step = self._env.reset()
|
||||
obs = dict(time_step.observation)
|
||||
obs["image"] = self.render()
|
||||
obs["is_terminal"] = False
|
||||
obs["is_terminal"] = False if time_step.first() else time_step.discount == 0
|
||||
obs["is_first"] = time_step.first()
|
||||
return obs
|
||||
|
||||
def render(self, *args, **kwargs):
|
||||
|
Loading…
x
Reference in New Issue
Block a user