Merge pull request #6 from truncs/main
That is correct. Thanks for your contribution!
This commit is contained in:
commit
16151efb3c
@@ -156,15 +156,15 @@ class Dreamer(nn.Module):
|
||||
start = post
|
||||
# start['deter'] (16, 64, 512)
|
||||
if self._config.pred_discount: # Last step could be terminal.
|
||||
- start = {k: v[:-1] for k, v in post.items()}
|
||||
- context = {k: v[:-1] for k, v in context.items()}
|
||||
+ start = {k: v[:, :-1] for k, v in post.items()}
|
||||
+ context = {k: v[:, :-1] for k, v in context.items()}
|
||||
reward = lambda f, s, a: self._wm.heads["reward"](
|
||||
self._wm.dynamics.get_feat(s)
|
||||
).mode()
|
||||
metrics.update(self._task_behavior._train(start, reward)[-1])
|
||||
if self._config.expl_behavior != "greedy":
|
||||
if self._config.pred_discount:
|
||||
- data = {k: v[:-1] for k, v in data.items()}
|
||||
+ data = {k: v[:, :-1] for k, v in data.items()}
|
||||
mets = self._expl_behavior.train(start, context, data)[-1]
|
||||
metrics.update({"expl_" + key: value for key, value in mets.items()})
|
||||
for name, value in metrics.items():
|
||||
|
Loading…
x
Reference in New Issue
Block a user