diff --git a/networks.py b/networks.py index 9c58faf..3a767fa 100644 --- a/networks.py +++ b/networks.py @@ -215,7 +215,7 @@ class RSSM(nn.Module): is_first, is_first.shape + (1,) * (len(val.shape) - len(is_first.shape)), ) - val = val * (1.0 - is_first_r) + init_state[key] * is_first_r + prev_state[key] = val * (1.0 - is_first_r) + init_state[key] * is_first_r prior = self.img_step(prev_state, prev_action, None, sample) if self._shared: