diff --git a/models.py b/models.py index 52a640e..a2448f1 100644 --- a/models.py +++ b/models.py @@ -366,11 +366,6 @@ class ImagBehavior(nn.Module): return imag_feat, imag_state, imag_action, weights, metrics def _imagine(self, start, policy, horizon, repeats=None): - # horizon: 15 - # start = dict(stoch, deter, logit) - # start["stoch"] (16, 63, 32, 32) - # start["deter"] (16, 63, 512) - # start["logit"] (16, 63, 32, 32) dynamics = self._world_model.dynamics if repeats: raise NotImplemented("repeats is not implemented in this version") @@ -385,12 +380,8 @@ class ImagBehavior(nn.Module): succ = dynamics.img_step(state, action, sample=self._config.imag_sample) return succ, feat, action - feat = 0 * dynamics.get_feat(start) - action = policy(feat).mode() - # Is this action deterministic or stochastic? - # action = policy(feat).sample() succ, feats, actions = tools.static_scan( - step, [torch.arange(horizon)], (start, feat, action) + step, [torch.arange(horizon)], (start, None, None) ) states = {k: torch.cat([start[k][None], v[:-1]], 0) for k, v in succ.items()} if repeats: