correct a misunderstanding: past actions get their own separate action token, while the agent token is used for predicting the next action, rewards, and values

lucidrains 2025-10-12 16:16:18 -07:00
parent 9c78962736
commit 6dbdc3d7d8
2 changed files with 29 additions and 21 deletions
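
In short: before this commit, the past-action embedding was summed into the agent tokens; now past actions are packed in as their own token, freeing the agent token to carry the next-action / reward / value predictions. A minimal before-and-after sketch, with hypothetical shapes rather than the repo's actual API:

import torch
from einops import pack

b, t, d = 2, 4, 8                          # batch, time, dim
agent_tokens = torch.randn(b, t, 1, d)     # one agent token per time step
action_embed = torch.randn(b, t, d)        # embedded past action per time step

# before: the action embed was summed directly onto the agent tokens
agent_with_action = agent_tokens + action_embed[:, :, None, :]

# after: past actions become their own token, packed alongside the agent token
action_tokens = action_embed[:, :, None, :]
tokens, ps = pack([action_tokens, agent_tokens], 'b t * d')
assert tokens.shape == (b, t, 2, d)        # action and agent now attend as separate tokens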


@@ -87,6 +87,9 @@ def is_power_two(num):

 # tensor helpers

+def is_empty(t):
+    return t.numel() == 0
+
 def log(t, eps = 1e-20):
     return t.clamp(min = eps).log()
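
The new is_empty helper simply checks for zero elements; a quick illustration with made-up tensors:

import torch

def is_empty(t):
    return t.numel() == 0

assert is_empty(torch.empty(2, 4, 0, 8))   # any zero-sized axis means zero elements
assert not is_empty(torch.zeros(1))        # a zero value is still an element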
@@ -1944,20 +1947,6 @@ class DynamicsWorldModel(Module):
         agent_tokens = repeat(agent_tokens, 'b ... d -> b t ... d', t = time)

-        # maybe add the action embed to the agent tokens per time step
-
-        if exists(discrete_actions) or exists(continuous_actions):
-            assert self.action_embedder.has_actions
-
-            action_embed = self.action_embedder(
-                discrete_actions = discrete_actions,
-                discrete_action_types = discrete_action_types,
-                continuous_actions = continuous_actions,
-                continuous_action_types = continuous_action_types
-            )
-
-            agent_tokens = einx.add('b t ... d, b t d', agent_tokens, action_embed)
-
         # maybe add a reward embedding to agent tokens

         if exists(rewards):
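
The reward path keeps the same einx.add pattern, which broadcasts a per-timestep embedding over the elided token axes; a minimal sketch, assuming hypothetical shapes:

import torch
import einx

agent_tokens = torch.randn(2, 4, 3, 8)   # (batch, time, agents, dim)
reward_embeds = torch.randn(2, 4, 8)     # (batch, time, dim)

# 'b t d' is aligned against 'b t ... d', broadcasting across the elided agent axis
out = einx.add('b t ... d, b t d', agent_tokens, reward_embeds)
assert out.shape == agent_tokens.shape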
@@ -1975,9 +1964,23 @@ class DynamicsWorldModel(Module):
             agent_tokens = einx.add('b t ... d, b t d', agent_tokens, reward_embeds)

+        # maybe create the action tokens
+
+        if exists(discrete_actions) or exists(continuous_actions):
+            assert self.action_embedder.has_actions
+
+            action_tokens = self.action_embedder(
+                discrete_actions = discrete_actions,
+                discrete_action_types = discrete_action_types,
+                continuous_actions = continuous_actions,
+                continuous_action_types = continuous_action_types
+            )
+        else:
+            action_tokens = agent_tokens[:, :, 0:0] # else empty, sliced off the agent tokens
+
         # main function, needs to be defined as such for shortcut training - additional calls for consistency loss

-        def get_prediction(noised_latents, signal_levels, step_sizes_log2, agent_tokens, return_agent_tokens = False):
+        def get_prediction(noised_latents, signal_levels, step_sizes_log2, action_tokens, agent_tokens, return_agent_tokens = False):

             # latents to spatial tokens

             space_tokens = self.latents_to_spatial_tokens(noised_latents)
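
When no actions are passed in, slicing 0:0 off the agent tokens yields an empty placeholder with matching dtype and device, so downstream code can treat both cases uniformly; for example:

import torch

agent_tokens = torch.randn(2, 4, 1, 8)
action_tokens = agent_tokens[:, :, 0:0]   # zero-width slice -> shape (2, 4, 0, 8)

assert action_tokens.shape == (2, 4, 0, 8)
assert action_tokens.numel() == 0         # so is_empty(action_tokens) is True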
@@ -1986,6 +1989,10 @@ class DynamicsWorldModel(Module):
             num_spatial_tokens = space_tokens.shape[-2]

+            # action tokens
+
+            num_action_tokens = 1 if not is_empty(action_tokens) else 0
+
             # pack to tokens
             # [signal + step size embed] [latent space tokens] [register] [actions / agent]
@@ -2003,7 +2010,7 @@ class DynamicsWorldModel(Module):
             # pack to tokens for attending

-            tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, agent_tokens], 'b t * d')
+            tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, action_tokens, agent_tokens], 'b t * d')
# attend functions for space and time
@ -2015,6 +2022,7 @@ class DynamicsWorldModel(Module):
space_seq_len = (
+ 1 # signal + step
+ num_action_tokens # past action tokens - todo: account for multi-agent
+ self.num_agents # action / agent tokens
+ self.num_register_tokens
+ num_spatial_tokens
@@ -2056,7 +2064,7 @@ class DynamicsWorldModel(Module):
             # unpack

-            flow_token, space_tokens, register_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
+            flow_token, space_tokens, register_tokens, action_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')

             # pooling
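
pack/unpack with the 'b t * d' wildcard keeps the bookkeeping for groups of different widths, so the new action tokens slot in without manual index math; a toy round trip with hypothetical widths:

import torch
from einops import pack, unpack

b, t, d = 2, 4, 8
flow_token    = torch.randn(b, t, 1, d)
space_tokens  = torch.randn(b, t, 16, d)
registers     = torch.randn(b, t, 4, d)
action_tokens = torch.randn(b, t, 1, d)
agent_tokens  = torch.randn(b, t, 1, d)

tokens, ps = pack([flow_token, space_tokens, registers, action_tokens, agent_tokens], 'b t * d')
assert tokens.shape == (b, t, 23, d)

# ps records each group's width, so unpack restores the exact split after attention
flow_token, space_tokens, registers, action_tokens, agent_tokens = unpack(tokens, ps, 'b t * d')
assert action_tokens.shape == (b, t, 1, d)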
@@ -2071,7 +2079,7 @@ class DynamicsWorldModel(Module):
         # forward the network

-        pred, encoded_agent_tokens = get_prediction(noised_latents, signal_levels, step_sizes_log2, agent_tokens, return_agent_tokens = True)
+        pred, encoded_agent_tokens = get_prediction(noised_latents, signal_levels, step_sizes_log2, action_tokens, agent_tokens, return_agent_tokens = True)

         if return_pred_only:
             if not return_agent_tokens:
@@ -2108,7 +2116,7 @@ class DynamicsWorldModel(Module):
             step_sizes_log2_minus_one = step_sizes_log2 - 1 # which equals d / 2
             half_step_size = 2 ** step_sizes_log2_minus_one

-            first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one, agent_tokens)
+            first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one, action_tokens, agent_tokens)

             # first derive b'
@@ -2127,7 +2135,7 @@ class DynamicsWorldModel(Module):
             # get second prediction for b''

             signal_levels_plus_half_step = signal_levels + half_step_size[:, None]

-            second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels_plus_half_step, step_sizes_log2_minus_one, agent_tokens)
+            second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels_plus_half_step, step_sizes_log2_minus_one, action_tokens, agent_tokens)

             if is_v_space_pred:
                 second_step_pred_flow = second_step_pred
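
For the shortcut consistency loss, the model is queried twice at half the step size; since step sizes are stored as log2, halving is just subtracting one:

import torch

step_sizes_log2 = torch.tensor([0., -1., -2.])      # step sizes 1, 1/2, 1/4
step_sizes_log2_minus_one = step_sizes_log2 - 1     # log2(d) - 1 == log2(d / 2)
half_step_size = 2 ** step_sizes_log2_minus_one

assert torch.equal(half_step_size, torch.tensor([0.5, 0.25, 0.125]))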

pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.17"
+version = "0.0.18"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }