diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 18e22f8..091be87 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -1519,7 +1519,7 @@ class DynamicsWorldModel(Module): assert self.num_tasks > 0 task_embeds = self.task_embed(tasks) - agent_tokens = agent_tokens + task_embeds + agent_tokens = einx.add('b ... d, b d', agent_tokens, task_embeds) # maybe evolution