From 46f86cd2478f6eb05c40648c2d5f856a125bdb85 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Tue, 28 Oct 2025 09:36:58 -0700 Subject: [PATCH] fix storing of agent embedding --- dreamer4/dreamer4.py | 8 +++++--- pyproject.toml | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 87ca6a1..091e4d9 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -2255,7 +2255,7 @@ class DynamicsWorldModel(Module): video = cat((video, next_frame), dim = 2) rewards = safe_cat((rewards, reward), dim = 1) - acc_agent_embed = safe_cat((acc_agent_embed, agent_embed), dim = 1) + acc_agent_embed = safe_cat((acc_agent_embed, one_agent_embed), dim = 1) # package up one experience for learning @@ -2397,7 +2397,7 @@ class DynamicsWorldModel(Module): return_intermediates = True ) - agent_embeds = agent_embeds[..., agent_index, :] + agent_embeds = agent_embeds[..., agent_index, :] # maybe detach agent embed @@ -2672,7 +2672,9 @@ class DynamicsWorldModel(Module): # maybe store agent embed - acc_agent_embed = safe_cat((acc_agent_embed, agent_embed), dim = 1) + if store_agent_embed: + one_agent_embed = agent_embed[:, -1:, agent_index] + acc_agent_embed = safe_cat((acc_agent_embed, one_agent_embed), dim = 1) # decode the agent actions if needed diff --git a/pyproject.toml b/pyproject.toml index f7d03d1..c4328f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.0.87" +version = "0.0.88" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }