Fix pred_discount slicing: the terminal value is along the sequence dim, so drop the last step with v[:, :-1] instead of v[:-1]

This commit is contained in:
Aditya 2023-04-20 19:05:18 -07:00
parent 1e070a3daf
commit 52782d31e3

View File

@@ -156,15 +156,15 @@ class Dreamer(nn.Module):
start = post
# start['deter'] (16, 64, 512)
if self._config.pred_discount: # Last step could be terminal.
-start = {k: v[:-1] for k, v in post.items()}
-context = {k: v[:-1] for k, v in context.items()}
+start = {k: v[:, :-1] for k, v in post.items()}
+context = {k: v[:, :-1] for k, v in context.items()}
reward = lambda f, s, a: self._wm.heads["reward"](
self._wm.dynamics.get_feat(s)
).mode()
metrics.update(self._task_behavior._train(start, reward)[-1])
if self._config.expl_behavior != "greedy":
if self._config.pred_discount:
-data = {k: v[:-1] for k, v in data.items()}
+data = {k: v[:, :-1] for k, v in data.items()}
mets = self._expl_behavior.train(start, context, data)[-1]
metrics.update({"expl_" + key: value for key, value in mets.items()})
for name, value in metrics.items():