From 52782d31e3e249d6f1034e046bae3a18f0b05ef3 Mon Sep 17 00:00:00 2001 From: Aditya <157234+truncs@users.noreply.github.com> Date: Thu, 20 Apr 2023 19:05:18 -0700 Subject: [PATCH] terminal value is along the sequence dim --- dreamer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dreamer.py b/dreamer.py index 71e7437..796051b 100644 --- a/dreamer.py +++ b/dreamer.py @@ -156,15 +156,15 @@ class Dreamer(nn.Module): start = post # start['deter'] (16, 64, 512) if self._config.pred_discount: # Last step could be terminal. - start = {k: v[:-1] for k, v in post.items()} - context = {k: v[:-1] for k, v in context.items()} + start = {k: v[:, :-1] for k, v in post.items()} + context = {k: v[:, :-1] for k, v in context.items()} reward = lambda f, s, a: self._wm.heads["reward"]( self._wm.dynamics.get_feat(s) ).mode() metrics.update(self._task_behavior._train(start, reward)[-1]) if self._config.expl_behavior != "greedy": if self._config.pred_discount: - data = {k: v[:-1] for k, v in data.items()} + data = {k: v[:, :-1] for k, v in data.items()} mets = self._expl_behavior.train(start, context, data)[-1] metrics.update({"expl_" + key: value for key, value in mets.items()}) for name, value in metrics.items():