Fixed RSSM trajectory starts initialization.

This commit is contained in:
ktolnos 2023-06-05 22:49:40 +03:00
parent ac8b846505
commit 308eeff09a

View File

@ -215,7 +215,7 @@ class RSSM(nn.Module):
is_first,
is_first.shape + (1,) * (len(val.shape) - len(is_first.shape)),
)
val = val * (1.0 - is_first_r) + init_state[key] * is_first_r
prev_state[key] = val * (1.0 - is_first_r) + init_state[key] * is_first_r
prior = self.img_step(prev_state, prev_action, None, sample)
if self._shared: