addition of "is_first" and "is_terminal" for envs
parent 3d0e2c8b5a
commit 12cccd8475
@@ -98,7 +98,7 @@ class Atari:
         if not self._repeat:
             self._buffer[1][:] = self._buffer[0][:]
         self._screen(self._buffer[0])
-        self._done = over or (self._length and self._step >= self._length) or dead
+        self._done = over or (self._length and self._step >= self._length)
         return self._obs(
             total,
             is_last=self._done or (dead and self._lives == "reset"),
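The Atari hunk above stops folding life loss ("dead") into self._done: the episode flag now depends only on game over and the length limit, while is_last can still end the stored trajectory when self._lives == "reset". A minimal sketch of the resulting flags on a life-loss step, with made-up locals standing in for the wrapper's state:

# Made-up values mirroring the wrapper's locals on a life-loss step.
dead, over = True, False                  # a life was lost, but the game is not over
length_exceeded = False                   # the self._length limit was not reached
lives = "reset"                           # stands in for self._lives

done = over or length_exceeded                   # life loss alone no longer sets done
is_last = done or (dead and lives == "reset")    # the trajectory still ends under "reset"
print(done, is_last)                             # False True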
@@ -137,7 +137,12 @@ class Atari:
         weights = [0.299, 0.587, 1 - (0.299 + 0.587)]
         image = np.tensordot(image, weights, (-1, 0)).astype(image.dtype)
         image = image[:, :, None]
-        return {"image": image, "is_terminal": is_terminal}, reward, is_last, {}
+        return (
+            {"image": image, "is_terminal": is_terminal, "is_first": is_first},
+            reward,
+            is_last,
+            {},
+        )
 
     def _screen(self, array):
         self._ale.getScreenRGB2(array)
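This hunk only widens the returned observation dict with "is_first"; the grayscale conversion is unchanged. For reference, a self-contained sketch of what the np.tensordot line computes, using a made-up 64x64 frame:

import numpy as np

# Luminance weights as in the diff: R, G, and the remainder for B.
weights = [0.299, 0.587, 1 - (0.299 + 0.587)]
rgb = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)  # made-up RGB frame
gray = np.tensordot(rgb, weights, (-1, 0)).astype(rgb.dtype)  # contract the channel axis
gray = gray[:, :, None]                                       # keep a trailing channel dim
print(gray.shape)                                             # (64, 64, 1)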
@@ -44,7 +44,8 @@ class DeepMindControl:
         obs = dict(time_step.observation)
         obs["image"] = self.render()
         # There is no terminal state in DMC
-        obs["is_terminal"] = False
+        obs["is_terminal"] = False if time_step.first() else time_step.discount == 0
+        obs["is_first"] = time_step.first()
         done = time_step.last()
         info = {"discount": np.array(time_step.discount, np.float32)}
         return obs, reward, done, info
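Both DeepMindControl hunks read the flags off the dm_env-style TimeStep: discount == 0 on the last step marks a genuine termination, a time-limit end keeps discount == 1, and first() marks the start of an episode. A small sketch of that mapping (the helper name is made up):

def step_flags(time_step):
    # Assumes a dm_env-style TimeStep with .first(), .last(), and .discount.
    is_first = time_step.first()
    # discount == 0 only on a genuine termination; a time-limit end is
    # "last" but keeps discount == 1, so it is not terminal.
    is_terminal = False if is_first else time_step.discount == 0
    is_last = time_step.last()
    return is_first, is_last, is_terminal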
@@ -53,7 +54,8 @@ class DeepMindControl:
         time_step = self._env.reset()
         obs = dict(time_step.observation)
         obs["image"] = self.render()
-        obs["is_terminal"] = False
+        obs["is_terminal"] = False if time_step.first() else time_step.discount == 0
+        obs["is_first"] = time_step.first()
         return obs
 
     def render(self, *args, **kwargs):