Fixed RSSM trajectory starts initialization.
This commit is contained in:
parent
ac8b846505
commit
308eeff09a
@ -215,7 +215,7 @@ class RSSM(nn.Module):
|
|||||||
is_first,
|
is_first,
|
||||||
is_first.shape + (1,) * (len(val.shape) - len(is_first.shape)),
|
is_first.shape + (1,) * (len(val.shape) - len(is_first.shape)),
|
||||||
)
|
)
|
||||||
val = val * (1.0 - is_first_r) + init_state[key] * is_first_r
|
prev_state[key] = val * (1.0 - is_first_r) + init_state[key] * is_first_r
|
||||||
|
|
||||||
prior = self.img_step(prev_state, prev_action, None, sample)
|
prior = self.img_step(prev_state, prev_action, None, sample)
|
||||||
if self._shared:
|
if self._shared:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user