diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 3f86e9e..c60ae67 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -2762,7 +2762,7 @@ class DynamicsWorldModel(Module): ): with world_model_forward_context(): - _, (agent_embeds, _) = self.forward( + _, (embeds, _) = self.forward( latents = latents, signal_levels = self.max_steps - 1, step_sizes = step_size, @@ -2774,7 +2774,7 @@ class DynamicsWorldModel(Module): return_intermediates = True ) - agent_embeds = agent_embeds[..., agent_index, :] + agent_embeds = embeds.agent[..., agent_index, :] # maybe detach agent embed diff --git a/pyproject.toml b/pyproject.toml index 8499691..5dca4f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.1.22" +version = "0.1.23" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }