From a358a44a53cf4e605a99b021abaad2e92ac284bc Mon Sep 17 00:00:00 2001 From: lucidrains Date: Wed, 29 Oct 2025 10:39:15 -0700 Subject: [PATCH] always store old agent embeds and old action parameters when possible --- dreamer4/dreamer4.py | 11 ++++++----- pyproject.toml | 2 +- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index c5b5e15..8e9b8f0 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -2225,8 +2225,8 @@ class DynamicsWorldModel(Module): max_timesteps = 16, env_is_vectorized = False, use_time_kv_cache = True, - store_agent_embed = False, - store_old_action_unembeds = False, + store_agent_embed = True, + store_old_action_unembeds = True, ): assert exists(self.video_tokenizer) @@ -2391,7 +2391,7 @@ class DynamicsWorldModel(Module): actions = (discrete_actions, continuous_actions), log_probs = (discrete_log_probs, continuous_log_probs), values = values, - old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if store_old_action_unembeds else None, + old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if exists(acc_policy_embed) and store_old_action_unembeds else None, agent_embed = acc_agent_embed if store_agent_embed else None, step_size = step_size, agent_index = agent_index, @@ -2667,7 +2667,8 @@ class DynamicsWorldModel(Module): return_agent_actions = False, return_log_probs_and_values = False, return_time_kv_cache = False, - store_agent_embed = False + store_agent_embed = True, + store_old_action_unembeds = True ): # (b t n d) | (b c t h w) @@ -2947,7 +2948,7 @@ class DynamicsWorldModel(Module): video = video, proprio = proprio if has_proprio else None, agent_embed = acc_agent_embed if store_agent_embed else None, - old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if store_old_action_unembeds else None, + old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if exists(acc_policy_embed) and store_old_action_unembeds else None, step_size = step_size, agent_index = agent_index, lens = experience_lens, diff --git a/pyproject.toml b/pyproject.toml index 5086d51..57fce7e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.0.93" +version = "0.0.94" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }