terminal value is along the sequence dim
This commit is contained in:
parent
1e070a3daf
commit
52782d31e3
@ -156,15 +156,15 @@ class Dreamer(nn.Module):
|
|||||||
start = post
|
start = post
|
||||||
# start['deter'] (16, 64, 512)
|
# start['deter'] (16, 64, 512)
|
||||||
if self._config.pred_discount: # Last step could be terminal.
|
if self._config.pred_discount: # Last step could be terminal.
|
||||||
start = {k: v[:-1] for k, v in post.items()}
|
start = {k: v[:, :-1] for k, v in post.items()}
|
||||||
context = {k: v[:-1] for k, v in context.items()}
|
context = {k: v[:, :-1] for k, v in context.items()}
|
||||||
reward = lambda f, s, a: self._wm.heads["reward"](
|
reward = lambda f, s, a: self._wm.heads["reward"](
|
||||||
self._wm.dynamics.get_feat(s)
|
self._wm.dynamics.get_feat(s)
|
||||||
).mode()
|
).mode()
|
||||||
metrics.update(self._task_behavior._train(start, reward)[-1])
|
metrics.update(self._task_behavior._train(start, reward)[-1])
|
||||||
if self._config.expl_behavior != "greedy":
|
if self._config.expl_behavior != "greedy":
|
||||||
if self._config.pred_discount:
|
if self._config.pred_discount:
|
||||||
data = {k: v[:-1] for k, v in data.items()}
|
data = {k: v[:, :-1] for k, v in data.items()}
|
||||||
mets = self._expl_behavior.train(start, context, data)[-1]
|
mets = self._expl_behavior.train(start, context, data)[-1]
|
||||||
metrics.update({"expl_" + key: value for key, value in mets.items()})
|
metrics.update({"expl_" + key: value for key, value in mets.items()})
|
||||||
for name, value in metrics.items():
|
for name, value in metrics.items():
|
||||||
|
Loading…
x
Reference in New Issue
Block a user