diff --git a/tools.py b/tools.py index 45048f5..4fcfdb6 100644 --- a/tools.py +++ b/tools.py @@ -364,6 +364,8 @@ def sample_episodes(episodes, length, seed=0): ret = { k: v[index : min(index + length, total)] for k, v in episode.items() } + if "is_first" in ret: + ret["is_first"][0] = True else: # 'is_first' comes after 'is_last' index = 0 @@ -374,6 +376,8 @@ def sample_episodes(episodes, length, seed=0): ) for k, v in episode.items() } + if "is_first" in ret: + ret["is_first"][size] = True size = len(next(iter(ret.values()))) yield ret