From 2d20d0a6c17fecbbc424a523885cfb0e81d7be22 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Thu, 16 Oct 2025 10:15:43 -0700
Subject: [PATCH] able to roll out actions from one agent within the dreams of
 a world model

---
 dreamer4/dreamer4.py  | 49 ++++++++++++++++++++++++++++++++++++++++---
 pyproject.toml        |  2 +-
 tests/test_dreamer.py | 40 +++++++++++++++++++++++++++++++++++
 3 files changed, 87 insertions(+), 4 deletions(-)

diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 8384384..344049b 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -63,6 +63,8 @@ TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))
 
 WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'reward', 'behavior_clone'))
 
+WorldModelGenerations = namedtuple('WorldModelGenerations', ('video', 'latents', 'rewards', 'actions'))
+
 # helpers
 
 def exists(v):
@@ -74,6 +76,9 @@ def default(v, d):
     return v if exists(v) else d
 
 def first(arr):
     return arr[0]
 
+def has_at_least_one(*bools):
+    return sum([*map(int, bools)]) > 0
+
 def ensure_tuple(t):
     return (t,) if not isinstance(t, tuple) else t
@@ -94,6 +99,16 @@ def is_empty(t):
 def log(t, eps = 1e-20):
     return t.clamp(min = eps).log()
 
+def safe_cat(*tensors, dim):
+    tensors = [*filter(exists, tensors)]
+
+    if len(tensors) == 0:
+        return None
+    elif len(tensors) == 1:
+        return tensors[0]
+
+    return cat(tensors, dim = dim)
+
 def gumbel_noise(t):
     noise = torch.rand_like(t)
     return -log(-log(noise))
@@ -1630,7 +1645,8 @@ class DynamicsWorldModel(Module):
         image_width = None,
         return_decoded_video = None,
         context_signal_noise = 0.1, # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc
-        return_rewards_per_frame = False
+        return_rewards_per_frame = False,
+        return_agent_actions = False
     ):
 
         # (b t n d) | (b c t h w)
@@ -1653,6 +1669,14 @@
 
         past_context_noise = latents.clone()
 
+        # maybe return actions
+
+        if return_agent_actions:
+            assert self.action_embedder.has_actions
+
+        decoded_discrete_actions = None
+        decoded_continuous_actions = None
+
         # maybe return rewards
 
         if return_rewards_per_frame:
@@ -1679,6 +1703,8 @@
                 signal_levels = signal_levels_with_context,
                 step_sizes = step_size,
                 rewards = decoded_rewards,
+                discrete_actions = decoded_discrete_actions,
+                continuous_actions = decoded_continuous_actions,
                 latent_is_noised = True,
                 return_pred_only = True,
                 return_agent_tokens = True
@@ -1710,6 +1736,18 @@
 
                 decoded_rewards = cat((decoded_rewards, pred_reward), dim = 1)
 
+            # decode the agent actions if needed
+
+            if return_agent_actions:
+                one_agent_embed = agent_embed[:, -1:, agent_index]
+
+                policy_embed = self.policy_head(one_agent_embed)
+
+                sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed)
+
+                decoded_discrete_actions = safe_cat(decoded_discrete_actions, sampled_discrete_actions, dim = 1)
+                decoded_continuous_actions = safe_cat(decoded_continuous_actions, sampled_continuous_actions, dim = 1)
+
             # concat the denoised latent
 
             latents = cat((latents, denoised_latent), dim = 1)
@@ -1739,10 +1777,15 @@
             width = image_width
         )
 
-        if not return_rewards_per_frame:
+        if not has_at_least_one(return_rewards_per_frame, return_agent_actions):
             return generated_video
 
-        return generated_video, decoded_rewards
+        return WorldModelGenerations(
+            video = generated_video,
+            latents = latents,
+            rewards = decoded_rewards if return_rewards_per_frame else None,
+            actions = (decoded_discrete_actions, decoded_continuous_actions) if return_agent_actions else None
+        )
 
     def forward(
         self,
diff --git a/pyproject.toml b/pyproject.toml
index 540bf3f..3167f30 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.24"
+version = "0.0.25"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index 484470d..ae3cff7 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -187,6 +187,46 @@ def test_attend_factory(
 
     assert torch.allclose(flex_out, out, atol = 1e-6)
 
+def test_action_with_world_model():
+    from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
+
+    tokenizer = VideoTokenizer(
+        512,
+        dim_latent = 32,
+        patch_size = 32,
+        encoder_depth = 1,
+        decoder_depth = 1,
+        attn_heads = 8,
+        image_height = 256,
+        image_width = 256,
+        attn_kwargs = dict(
+            query_heads = 16
+        )
+    )
+
+    dynamics = DynamicsWorldModel(
+        512,
+        num_agents = 1,
+        video_tokenizer = tokenizer,
+        dim_latent = 32,
+        num_discrete_actions = 4
+    )
+
+    rewards = torch.randn(1, 4)
+    discrete_actions = torch.randint(0, 4, (1, 4, 1))
+
+    generated_video, _, generated_rewards, (discrete_actions, continuous_actions) = dynamics.generate(
+        10,
+        return_rewards_per_frame = True,
+        return_agent_actions = True
+    )
+
+    assert generated_video.shape == (1, 3, 10, 256, 256)
+    assert generated_rewards.shape == (1, 10)
+
+    assert discrete_actions.shape == (1, 10, 1)
+    assert continuous_actions is None
+
 def test_action_embedder():
     from dreamer4.dreamer4 import ActionEmbedder
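
Usage note (a sketch, not part of the patch): with this change, generate returns a WorldModelGenerations namedtuple whenever return_rewards_per_frame or return_agent_actions is set, so the result can also be read by field name rather than by position. The constructor arguments below simply mirror test_action_with_world_model above; the variable names and shape comments are illustrative, restating the asserts from that test.

# sketch: consuming the WorldModelGenerations namedtuple by field name,
# using the same hypothetical setup as test_action_with_world_model
from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel

tokenizer = VideoTokenizer(
    512,
    dim_latent = 32,
    patch_size = 32,
    encoder_depth = 1,
    decoder_depth = 1,
    attn_heads = 8,
    image_height = 256,
    image_width = 256,
    attn_kwargs = dict(query_heads = 16)
)

dynamics = DynamicsWorldModel(
    512,
    num_agents = 1,
    video_tokenizer = tokenizer,
    dim_latent = 32,
    num_discrete_actions = 4
)

generations = dynamics.generate(
    10,
    return_rewards_per_frame = True,
    return_agent_actions = True
)

# named fields: video, latents, rewards, actions
video = generations.video      # (1, 3, 10, 256, 256)
rewards = generations.rewards  # (1, 10)

# actions is a (discrete, continuous) pair; continuous is None when only discrete actions are configured
discrete_actions, continuous_actions = generations.actions  # discrete: (1, 10, 1)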