diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 87ca6a1..091e4d9 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -2255,7 +2255,7 @@ class DynamicsWorldModel(Module): video = cat((video, next_frame), dim = 2) rewards = safe_cat((rewards, reward), dim = 1) - acc_agent_embed = safe_cat((acc_agent_embed, agent_embed), dim = 1) + acc_agent_embed = safe_cat((acc_agent_embed, one_agent_embed), dim = 1) # package up one experience for learning @@ -2397,7 +2397,7 @@ class DynamicsWorldModel(Module): return_intermediates = True ) - agent_embeds = agent_embeds[..., agent_index, :] + agent_embeds = agent_embeds[..., agent_index, :] # maybe detach agent embed @@ -2672,7 +2672,9 @@ class DynamicsWorldModel(Module): # maybe store agent embed - acc_agent_embed = safe_cat((acc_agent_embed, agent_embed), dim = 1) + if store_agent_embed: + one_agent_embed = agent_embed[:, -1:, agent_index] + acc_agent_embed = safe_cat((acc_agent_embed, one_agent_embed), dim = 1) # decode the agent actions if needed diff --git a/pyproject.toml b/pyproject.toml index f7d03d1..c4328f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.0.87" +version = "0.0.88" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }