From 2e7f406d4974212f9d4239c59d1fe0b5222019a1 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Thu, 13 Nov 2025 16:37:35 -0800 Subject: [PATCH] allow for the combining of experiences from environment and dream --- dreamer4/dreamer4.py | 20 ++++++++++++++++---- pyproject.toml | 2 +- tests/test_dreamer.py | 8 ++++++-- 3 files changed, 23 insertions(+), 7 deletions(-) diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 9865192..3881307 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -96,7 +96,7 @@ class Experience: lens: MaybeTensor = None is_truncated: MaybeTensor = None agent_index: int = 0 - is_from_world_model: bool = True + is_from_world_model: bool | Tensor = True def cpu(self): return self.to(torch.device('cpu')) @@ -124,6 +124,9 @@ def combine_experiences( if not exists(exp.is_truncated): exp.is_truncated = full((batch,), True, device = device) + if isinstance(exp.is_from_world_model, bool): + exp.is_from_world_model = tensor(exp.is_from_world_model) + # convert to dictionary exps_dict = [asdict(exp) for exp in exps] @@ -145,11 +148,15 @@ def combine_experiences( for field_values in all_field_values: - if is_tensor(first(field_values)): + first_value = first(field_values) + + if is_tensor(first_value): field_values = pad_tensors_at_dim_to_max_len(field_values, dims = (1, 2)) - new_field_value = cat(field_values) + cat_or_stack = cat if first_value.ndim > 0 else stack + + new_field_value = cat_or_stack(field_values) else: new_field_value = first(list(set(field_values))) @@ -2408,7 +2415,7 @@ class DynamicsWorldModel(Module): env, seed = None, agent_index = 0, - step_size = 4, + num_steps = 4, max_timesteps = 16, env_is_vectorized = False, use_time_cache = True, @@ -2448,6 +2455,11 @@ class DynamicsWorldModel(Module): episode_lens = full((batch,), 0, device = device) + # derive step size + + assert divisible_by(self.max_steps, num_steps) + step_size = self.max_steps // num_steps + # maybe time kv cache time_cache = None diff --git a/pyproject.toml b/pyproject.toml index 72c6d3f..12ac647 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.1.19" +version = "0.1.20" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" } diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py index eab9d36..c1ab74c 100644 --- a/tests/test_dreamer.py +++ b/tests/test_dreamer.py @@ -643,7 +643,9 @@ def test_online_rl( dim_latent = 16, patch_size = 32, attn_dim_head = 16, - num_latent_tokens = 1 + num_latent_tokens = 1, + image_height = 256, + image_width = 256, ) world_model_and_policy = DynamicsWorldModel( @@ -677,10 +679,12 @@ def test_online_rl( # manually + dream_experience = world_model_and_policy.generate(10, batch_size = 1, store_agent_embed = store_agent_embed, return_for_policy_optimization = True) + one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed) another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed) - combined_experience = combine_experiences([one_experience, another_experience]) + combined_experience = combine_experiences([dream_experience, one_experience, another_experience]) # quick test moving the experience to different devices