make sure "is_first" is set 0 at beginning

This commit is contained in:
NM512 2023-07-22 21:08:53 +09:00
parent f07d843953
commit 03d91cb2c1

View File

@ -364,6 +364,8 @@ def sample_episodes(episodes, length, seed=0):
ret = {
k: v[index : min(index + length, total)] for k, v in episode.items()
}
if "is_first" in ret:
ret["is_first"][0] = True
else:
# 'is_first' comes after 'is_last'
index = 0
@ -374,6 +376,8 @@ def sample_episodes(episodes, length, seed=0):
)
for k, v in episode.items()
}
if "is_first" in ret:
ret["is_first"][size] = True
size = len(next(iter(ret.values())))
yield ret