From 6dbdc3d7d81690dac321bd189162b713d5845a5c Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Sun, 12 Oct 2025 16:16:18 -0700
Subject: [PATCH] correct a misunderstanding: past actions get their own action
 token, while the agent token is used for predicting the next action, rewards,
 and values

---
 dreamer4/dreamer4.py | 48 ++++++++++++++++++++++++++------------------
 pyproject.toml       |  2 +-
 2 files changed, 29 insertions(+), 21 deletions(-)

diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 1c9e5e1..18bf488 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -87,6 +87,9 @@ def is_power_two(num):
 
 # tensor helpers
 
+def is_empty(t):
+    return t.numel() == 0
+
 def log(t, eps = 1e-20):
     return t.clamp(min = eps).log()
 
@@ -1944,20 +1947,6 @@ class DynamicsWorldModel(Module):
 
         agent_tokens = repeat(agent_tokens, 'b ... d -> b t ... d', t = time)
 
-        # maybe add the action embed to the agent tokens per time step
-
-        if exists(discrete_actions) or exists(continuous_actions):
-            assert self.action_embedder.has_actions
-
-            action_embed = self.action_embedder(
-                discrete_actions = discrete_actions,
-                discrete_action_types = discrete_action_types,
-                continuous_actions = continuous_actions,
-                continuous_action_types = continuous_action_types
-            )
-
-            agent_tokens = einx.add('b t ... d, b t d', agent_tokens, action_embed)
-
         # maybe add a reward embedding to agent tokens
 
         if exists(rewards):
@@ -1975,9 +1964,23 @@ class DynamicsWorldModel(Module):
 
             agent_tokens = einx.add('b t ... d, b t d', agent_tokens, reward_embeds)
 
+        # maybe create the action tokens
+
+        if exists(discrete_actions) or exists(continuous_actions):
+            assert self.action_embedder.has_actions
+
+            action_tokens = self.action_embedder(
+                discrete_actions = discrete_actions,
+                discrete_action_types = discrete_action_types,
+                continuous_actions = continuous_actions,
+                continuous_action_types = continuous_action_types
+            )
+        else:
+            action_tokens = agent_tokens[:, :, 0:0] # else empty off agent tokens
+
         # main function, needs to be defined as such for shortcut training - additional calls for consistency loss
 
-        def get_prediction(noised_latents, signal_levels, step_sizes_log2, agent_tokens, return_agent_tokens = False):
+        def get_prediction(noised_latents, signal_levels, step_sizes_log2, action_tokens, agent_tokens, return_agent_tokens = False):
 
             # latents to spatial tokens
 
             space_tokens = self.latents_to_spatial_tokens(noised_latents)
@@ -1986,6 +1989,10 @@ class DynamicsWorldModel(Module):
 
             num_spatial_tokens = space_tokens.shape[-2]
 
+            # action tokens
+
+            num_action_tokens = 1 if not is_empty(action_tokens) else 0
+
             # pack to tokens
             # [signal + step size embed] [latent space tokens] [register] [actions / agent]
@@ -2003,7 +2010,7 @@ class DynamicsWorldModel(Module):
 
             # pack to tokens for attending
 
-            tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, agent_tokens], 'b t * d')
+            tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, action_tokens, agent_tokens], 'b t * d')
 
             # attend functions for space and time
@@ -2015,6 +2022,7 @@ class DynamicsWorldModel(Module):
 
             space_seq_len = (
                 + 1 # signal + step
+                + num_action_tokens # past action tokens - todo: account for multi-agent
                 + self.num_agents # action / agent tokens
                 + self.num_register_tokens
                 + num_spatial_tokens
@@ -2056,7 +2064,7 @@ class DynamicsWorldModel(Module):
 
             # unpack
 
-            flow_token, space_tokens, register_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
+            flow_token, space_tokens, register_tokens, action_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
 
             # pooling
@@ -2071,7 +2079,7 @@ class DynamicsWorldModel(Module):
 
         # forward the network
 
-        pred, encoded_agent_tokens = get_prediction(noised_latents, signal_levels, step_sizes_log2, agent_tokens, return_agent_tokens = True)
+        pred, encoded_agent_tokens = get_prediction(noised_latents, signal_levels, step_sizes_log2, action_tokens, agent_tokens, return_agent_tokens = True)
 
         if return_pred_only:
             if not return_agent_tokens:
@@ -2108,7 +2116,7 @@ class DynamicsWorldModel(Module):
 
             step_sizes_log2_minus_one = step_sizes_log2 - 1 # which equals d / 2
             half_step_size = 2 ** step_sizes_log2_minus_one
 
-            first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one, agent_tokens)
+            first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one, action_tokens, agent_tokens)
 
             # first derive b'
@@ -2127,7 +2135,7 @@ class DynamicsWorldModel(Module):
 
             # get second prediction for b''
 
             signal_levels_plus_half_step = signal_levels + half_step_size[:, None]
 
-            second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels_plus_half_step, step_sizes_log2_minus_one, agent_tokens)
+            second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels_plus_half_step, step_sizes_log2_minus_one, action_tokens, agent_tokens)
 
             if is_v_space_pred:
                 second_step_pred_flow = second_step_pred
diff --git a/pyproject.toml b/pyproject.toml
index 8832486..0c004a6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.17"
+version = "0.0.18"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
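
As a quick orientation for readers of the diff: below is a minimal, self-contained sketch of the token layout this patch moves to. Only the pack/unpack pattern and the variable names follow dreamer4.py above; the tensor shapes, the toy data, and the has_actions flag are illustrative assumptions, not the library's API.

    # toy shapes: batch 2, time 4, model dim 8 (assumed for illustration only)

    import torch
    from einops import pack, unpack

    b, t, d = 2, 4, 8

    flow_token   = torch.randn(b, t, 1, d)   # signal + step size embed
    space_tokens = torch.randn(b, t, 16, d)  # latent spatial tokens
    registers    = torch.randn(b, t, 4, d)   # register tokens
    agent_tokens = torch.randn(b, t, 1, d)   # reserved for predicting next action, rewards, values

    # before this patch, the past-action embed was summed into agent_tokens;
    # after it, the past action is its own token, or a zero-length tensor when absent

    has_actions = True
    action_tokens = torch.randn(b, t, 1, d) if has_actions else agent_tokens[:, :, 0:0]

    tokens, packed_shape = pack([flow_token, space_tokens, registers, action_tokens, agent_tokens], 'b t * d')

    print(tokens.shape)  # torch.Size([2, 4, 23, 8]) -> 1 + 16 + 4 + 1 + 1 tokens per time step

    # ... space / time attention would run over `tokens` here ...

    flow_token, space_tokens, registers, action_tokens, agent_tokens = unpack(tokens, packed_shape, 'b t * d')

Since the past action now rides along as its own element of the packed sequence instead of being added into the agent token, the agent token's readout for the next action, rewards, and values is no longer entangled with the embedding of the action already taken.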