erased unnecessary code
This commit is contained in:
parent
432a359bcf
commit
2a8b44eb0c
11
models.py
11
models.py
@ -366,11 +366,6 @@ class ImagBehavior(nn.Module):
|
||||
return imag_feat, imag_state, imag_action, weights, metrics
|
||||
|
||||
def _imagine(self, start, policy, horizon, repeats=None):
|
||||
# horizon: 15
|
||||
# start = dict(stoch, deter, logit)
|
||||
# start["stoch"] (16, 63, 32, 32)
|
||||
# start["deter"] (16, 63, 512)
|
||||
# start["logit"] (16, 63, 32, 32)
|
||||
dynamics = self._world_model.dynamics
|
||||
if repeats:
|
||||
raise NotImplemented("repeats is not implemented in this version")
|
||||
@ -385,12 +380,8 @@ class ImagBehavior(nn.Module):
|
||||
succ = dynamics.img_step(state, action, sample=self._config.imag_sample)
|
||||
return succ, feat, action
|
||||
|
||||
feat = 0 * dynamics.get_feat(start)
|
||||
action = policy(feat).mode()
|
||||
# Is this action deterministic or stochastic?
|
||||
# action = policy(feat).sample()
|
||||
succ, feats, actions = tools.static_scan(
|
||||
step, [torch.arange(horizon)], (start, feat, action)
|
||||
step, [torch.arange(horizon)], (start, None, None)
|
||||
)
|
||||
states = {k: torch.cat([start[k][None], v[:-1]], 0) for k, v in succ.items()}
|
||||
if repeats:
|
||||
|
Loading…
x
Reference in New Issue
Block a user