last commit for the day - take care of the task embed

lucidrains 2025-10-05 11:40:48 -07:00
parent fe99efecba
commit f507afa0d3
2 changed files with 44 additions and 10 deletions

View File

@@ -982,6 +982,7 @@ class DynamicsModel(Module):
max_steps = 64, # K_max in paper
num_spatial_tokens = 32, # latents were projected into spatial tokens, and presumably pooled back for the final prediction (or one special one does the x-prediction)
num_register_tokens = 8, # they claim register tokens led to better temporal consistency
num_tasks = 0,
depth = 4,
pred_orig_latent = True, # directly predicting the original x0 data yields better results, rather than velocity (x-space vs v-space)
time_block_every = 4, # every 4th block is time
@@ -1029,10 +1030,15 @@
self.pred_orig_latent = pred_orig_latent # x-space or v-space
self.loss_weight_fn = loss_weight_fn
# reinforcement learning related
# they sum all the actions into a single token
self.action_learned_embed = Parameter(torch.randn(dim) * 1e-2)
self.num_tasks = num_tasks
self.task_embed = nn.Embedding(num_tasks, dim)
# calculate "space" seq len
self.space_seq_len = (
@@ -1096,6 +1102,7 @@
latents = None, # (b t d)
signal_levels = None, # (b t)
step_sizes_log2 = None, # (b)
tasks = None, # (b)
return_pred_only = False
):
# handle video or latents
@@ -1145,7 +1152,19 @@
noised_latents = noise.lerp(latents, times)
def get_prediction(noised_latents, signal_levels, step_sizes_log2):
# reinforcement learning related
agent_tokens = repeat(self.action_learned_embed, 'd -> b d', b = batch)
if exists(tasks):
assert self.num_tasks > 0
task_embeds = self.task_embed(tasks)
agent_tokens = agent_tokens + task_embeds
# main function, needs to be defined as such for shortcut training - additional calls for consistency loss
def get_prediction(noised_latents, signal_levels, step_sizes_log2, agent_tokens):
# latents to spatial tokens
space_tokens = self.latents_to_spatial_tokens(noised_latents)
@@ -1155,8 +1174,6 @@
registers = repeat(self.register_tokens, 's d -> b t s d', b = batch, t = time)
agent_token = repeat(self.action_learned_embed, 'd -> b t d', b = batch, t = time)
# determine signal + step size embed for their diffusion forcing + shortcut
signal_embed = self.signal_levels_embed(signal_levels)
@@ -1167,9 +1184,13 @@
flow_token = cat((signal_embed, step_size_embed), dim = -1)
flow_token = rearrange(flow_token, 'b t d -> b t d')
# handle agent tokens w/ actions and task embeds
agent_tokens = repeat(agent_tokens, 'b d -> b t d', t = time)
# pack to tokens for attending
tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, agent_token], 'b t * d')
tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, agent_tokens], 'b t * d')
# attend functions for space and time
@@ -1211,7 +1232,7 @@
# unpack
flow_token, space_tokens, register_tokens, agent_token = unpack(tokens, packed_tokens_shape, 'b t * d')
flow_token, space_tokens, register_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
# pooling
@@ -1223,7 +1244,7 @@
# forward the network
pred = get_prediction(noised_latents, signal_levels, step_sizes_log2)
pred = get_prediction(noised_latents, signal_levels, step_sizes_log2, agent_tokens)
if return_pred_only:
return pred
@@ -1257,7 +1278,7 @@
step_sizes_log2_minus_one = step_sizes_log2 - 1 # log2 of the halved step size, i.e. log2(d / 2)
half_step_size = 2 ** step_sizes_log2_minus_one
first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one)
first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one, agent_tokens)
# first derive b'
@@ -1273,7 +1294,7 @@
# get second prediction for b''
second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels + half_step_size[:, None], step_sizes_log2_minus_one)
second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels + half_step_size[:, None], step_sizes_log2_minus_one, agent_tokens)
if is_v_space_pred:
second_step_pred_flow = second_step_pred
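
For reference, the task conditioning added above is a small amount of plumbing: a single learned agent token, an optional task embedding from nn.Embedding(num_tasks, dim) summed onto it, and a broadcast over time before packing with the flow, space, and register tokens; the same conditioned token is then handed to every call of get_prediction, including the two no-grad half-step calls used for the shortcut consistency loss. Below is a minimal standalone sketch of just that path - the module name, forward signature, and shapes are illustrative assumptions, not the repo's API.

    # minimal sketch of the task-conditioned agent token - illustrative, not dreamer4's API

    import torch
    from torch import nn
    from torch.nn import Parameter
    from einops import repeat

    class TaskConditionedAgentToken(nn.Module):
        def __init__(self, dim, num_tasks = 0):
            super().__init__()
            # single learned agent token, actions / task info are summed into it
            self.action_learned_embed = Parameter(torch.randn(dim) * 1e-2)
            self.num_tasks = num_tasks
            self.task_embed = nn.Embedding(num_tasks, dim)

        def forward(self, batch, time, tasks = None):
            # (d) -> (b d)
            agent_tokens = repeat(self.action_learned_embed, 'd -> b d', b = batch)

            if tasks is not None:
                assert self.num_tasks > 0
                agent_tokens = agent_tokens + self.task_embed(tasks)  # (b d) + (b d)

            # broadcast over time before packing with flow / space / register tokens
            return repeat(agent_tokens, 'b d -> b t d', t = time)  # (b t d)

    # usage
    cond = TaskConditionedAgentToken(dim = 512, num_tasks = 4)
    agent_tokens = cond(batch = 2, time = 3, tasks = torch.randint(0, 4, (2,)))
    assert agent_tokens.shape == (2, 3, 512)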

View File

@@ -6,11 +6,13 @@ import torch
@param('grouped_query_attn', (False, True))
@param('dynamics_with_video_input', (False, True))
@param('prob_no_shortcut_train', (None, 0., 1.))
@param('add_task_embeds', (False, True))
def test_e2e(
pred_orig_latent,
grouped_query_attn,
dynamics_with_video_input,
prob_no_shortcut_train
prob_no_shortcut_train,
add_task_embeds
):
from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel
@@ -30,6 +32,7 @@ def test_e2e(
video_tokenizer = tokenizer,
dim_latent = 32,
max_steps = 64,
num_tasks = 4,
pred_orig_latent = pred_orig_latent,
attn_kwargs = dict(
heads = heads,
@@ -46,7 +49,17 @@
else:
dynamics_input = dict(latents = latents)
flow_loss = dynamics(**dynamics_input, signal_levels = signal_levels, step_sizes_log2 = step_sizes_log2)
tasks = None
if add_task_embeds:
tasks = torch.randint(0, 4, (2,))
flow_loss = dynamics(
**dynamics_input,
tasks = tasks,
signal_levels = signal_levels,
step_sizes_log2 = step_sizes_log2
)
assert flow_loss.numel() == 1
def test_symexp_two_hot():
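
The test change follows the usual pattern for optional conditioning: a new parametrize axis, with tasks left as None on the old path so every pre-existing case runs unchanged. A self-contained sketch of that pattern with a stub in place of DynamicsModel - the stub, and treating @param as an alias of pytest.mark.parametrize, are assumptions for illustration.

    import torch
    import pytest

    # assumed alias, mirroring the repo's test decorator
    param = pytest.mark.parametrize

    @param('add_task_embeds', (False, True))
    def test_optional_task_ids(add_task_embeds):
        num_tasks, batch = 4, 2

        # tasks stays None on the old code path, so prior cases are untouched
        tasks = torch.randint(0, num_tasks, (batch,)) if add_task_embeds else None

        # stand-in for DynamicsModel(...), exercising only the task id plumbing
        def stub_dynamics(tasks = None):
            if tasks is not None:
                assert tasks.dtype == torch.long
                assert int(tasks.max()) < num_tasks
            return torch.zeros(())  # scalar stand-in for the flow loss

        loss = stub_dynamics(tasks = tasks)
        assert loss.numel() == 1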