From d476fa7b14caaa95c62b1b2767acd200dd7de403 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Tue, 28 Oct 2025 09:02:26 -0700
Subject: [PATCH] able to store the agent embeddings during rollouts with
 imagination or environment, for efficient policy optimization (but will also
 allow for finetuning world model for the heads)

---
 dreamer4/dreamer4.py  | 19 ++++++++++++++++++-
 pyproject.toml        |  2 +-
 tests/test_dreamer.py | 11 ++++++++---
 3 files changed, 27 insertions(+), 5 deletions(-)

diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 91c4299..153a13e 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -77,6 +77,7 @@ class Experience:
     latents: Tensor
     video: Tensor | None = None
     proprio: Tensor | None = None
+    agent_embed: Tensor | None = None
     rewards: Tensor | None = None
     actions: tuple[Tensor, Tensor] | None = None
     log_probs: tuple[Tensor, Tensor] | None = None
@@ -2105,7 +2106,8 @@ class DynamicsWorldModel(Module):
         step_size = 4,
         max_timesteps = 16,
         env_is_vectorized = False,
-        use_time_kv_cache = True
+        use_time_kv_cache = True,
+        store_agent_embed = False
     ):
         assert exists(self.video_tokenizer)
 
@@ -2130,6 +2132,8 @@ class DynamicsWorldModel(Module):
         values = None
         latents = None
 
+        acc_agent_embed = None
+
         # keep track of termination, for setting the `is_truncated` field on Experience and for early stopping interaction with env
 
         is_terminated = full((batch,), False, device = device)
@@ -2251,6 +2255,8 @@ class DynamicsWorldModel(Module):
             video = cat((video, next_frame), dim = 2)
 
             rewards = safe_cat((rewards, reward), dim = 1)
 
+            acc_agent_embed = safe_cat((acc_agent_embed, agent_embed), dim = 1)
+
         # package up one experience for learning
 
         batch, device = latents.shape[0], latents.device
@@ -2262,6 +2268,7 @@ class DynamicsWorldModel(Module):
             actions = (discrete_actions, continuous_actions),
             log_probs = (discrete_log_probs, continuous_log_probs),
             values = values,
+            agent_embed = acc_agent_embed if store_agent_embed else None,
             step_size = step_size,
             agent_index = agent_index,
             is_truncated = is_truncated,
@@ -2491,6 +2498,7 @@ class DynamicsWorldModel(Module):
         return_agent_actions = False,
         return_log_probs_and_values = False,
         return_time_kv_cache = False,
+        store_agent_embed = False
     ):
         # (b t n d) | (b c t h w)
 
@@ -2543,6 +2551,10 @@ class DynamicsWorldModel(Module):
         decoded_continuous_log_probs = None
         decoded_values = None
 
+        # maybe store agent embed
+
+        acc_agent_embed = None
+
         # maybe return rewards
 
         decoded_rewards = None
@@ -2651,6 +2663,10 @@ class DynamicsWorldModel(Module):
 
                 decoded_rewards = cat((decoded_rewards, pred_reward), dim = 1)
 
+            # maybe store agent embed
+
+            acc_agent_embed = safe_cat((acc_agent_embed, agent_embed), dim = 1)
+
             # decode the agent actions if needed
 
             if return_agent_actions:
@@ -2747,6 +2763,7 @@ class DynamicsWorldModel(Module):
             latents = latents,
             video = video,
             proprio = proprio if has_proprio else None,
+            agent_embed = acc_agent_embed if store_agent_embed else None,
             step_size = step_size,
             agent_index = agent_index,
             lens = experience_lens,
diff --git a/pyproject.toml b/pyproject.toml
index 4ce3fa3..1448b38 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.83"
+version = "0.0.85"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index 0be14df..518e8e3 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -614,11 +614,13 @@ def test_cache_generate():
 @param('use_signed_advantage', (False, True))
 @param('env_can_terminate', (False, True))
 @param('env_can_truncate', (False, True))
+@param('store_agent_embed', (False, True))
 def test_online_rl(
     vectorized,
     use_signed_advantage,
     env_can_terminate,
-    env_can_truncate
+    env_can_truncate,
+    store_agent_embed
 ):
     from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer
 
@@ -664,11 +666,14 @@ def test_online_rl(
 
     # manually
 
-    one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = vectorized)
-    another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized)
+    one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)
+    another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)
 
     combined_experience = combine_experiences([one_experience, another_experience])
 
+    if store_agent_embed:
+        assert exists(combined_experience.agent_embed)
+
     actor_loss, critic_loss = world_model_and_policy.learn_from_experience(combined_experience, use_signed_advantage = use_signed_advantage)
 
     actor_loss.backward()
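
For readers who want to try the new flag directly, below is a minimal sketch of the flow the updated test exercises, not code from the patch itself. It assumes a `DynamicsWorldModel` instance (`world_model_and_policy`) and an environment (`mock_env`) have already been constructed as in `test_online_rl`, and it assumes `combine_experiences` is importable from `dreamer4.dreamer4`; everything else mirrors the calls shown in the diff above.

```python
# sketch only: rollout collection with agent embeddings retained, followed by
# policy optimization on the combined experience. `world_model_and_policy` and
# `mock_env` are assumed to be set up as in `test_online_rl`.

from dreamer4.dreamer4 import combine_experiences  # import path assumed

# interact with the environment, keeping per-timestep agent embeddings
# on the returned Experience via the new `store_agent_embed` flag

one_experience = world_model_and_policy.interact_with_env(
    mock_env,
    max_timesteps = 8,
    store_agent_embed = True
)

another_experience = world_model_and_policy.interact_with_env(
    mock_env,
    max_timesteps = 16,
    store_agent_embed = True
)

# rollouts of different lengths can be merged before learning

combined_experience = combine_experiences([one_experience, another_experience])

# the stored embeddings travel with the experience, so the actor / critic
# update (and, per the commit message, later finetuning of the world model
# heads) can reuse them instead of recomputing them

assert combined_experience.agent_embed is not None

actor_loss, critic_loss = world_model_and_policy.learn_from_experience(combined_experience)

actor_loss.backward()
```

The same flag is also added to the imagination rollout path (the second signature changed in `dreamer4/dreamer4.py`), so experiences generated purely inside the world model can carry `agent_embed` in the same way.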