take a step towards variable length experiences during training

lucidrains 2025-10-25 10:45:34 -07:00
parent 77a40e8701
commit 3d5617d769
3 changed files with 46 additions and 6 deletions

View File

@@ -155,6 +155,15 @@ def is_power_two(num):
 def is_empty(t):
     return t.numel() == 0

+def lens_to_mask(t, max_len = None):
+    if not exists(max_len):
+        max_len = t.amax().item()
+
+    device = t.device
+    seq = torch.arange(max_len, device = device)
+    return einx.less('j, i -> i j', seq, t)
+
 def log(t, eps = 1e-20):
     return t.clamp(min = eps).log()
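
For reference, the new lens_to_mask helper builds a per-frame validity mask from per-sample lengths. A minimal sketch of the same comparison in plain torch broadcasting (the einx.less call above is equivalent; the numbers are illustrative):

import torch

lens = torch.tensor([2, 4])          # (b,) number of valid frames per sample
max_len = 4
seq = torch.arange(max_len)          # (max_len,)

mask = seq[None, :] < lens[:, None]  # (b, max_len) boolean mask
# tensor([[ True,  True, False, False],
#         [ True,  True,  True,  True]])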
@@ -2581,6 +2590,7 @@ class DynamicsWorldModel(Module):
         *,
         video = None, # (b v? c t vh vw)
         latents = None, # (b t v? n d) | (b t v? d)
+        lens = None, # (b)
         signal_levels = None, # () | (b) | (b t)
         step_sizes = None, # () | (b)
         step_sizes_log2 = None, # () | (b)
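
The new lens kwarg is optional and follows the shape conventions commented above. A hypothetical set of inputs consistent with those comments (sizes are made up, and only the latents / lens pair is shown):

import torch

batch, frames, tokens, dim = 2, 3, 32, 64

latents = torch.randn(batch, frames, tokens, dim)  # (b t n d)
lens = torch.randint(1, frames + 1, (batch,))      # (b), frames past lens[i] are treated as padding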
@@ -3014,7 +3024,19 @@
         flow_losses = flow_losses * loss_weight

-        flow_loss = flow_losses.mean()
+        # handle variable lengths if needed
+
+        is_var_len = exists(lens)
+
+        if is_var_len:
+            loss_mask = lens_to_mask(lens, time)
+            loss_mask_without_last = loss_mask[:, :-1]
+
+            flow_loss = flow_losses[loss_mask].mean()
+        else:
+            flow_loss = flow_losses.mean()

         # now take care of the agent token losses
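
A toy illustration of why the masked flow-loss reduction above matters: padded frames no longer contribute to the loss. The mask here is what lens_to_mask(lens, time) yields in the diff (the numbers are made up):

import torch

flow_losses = torch.tensor([[1., 1., 9., 9.],
                            [1., 1., 1., 1.]])        # (b t), padded frames inflated on purpose
lens = torch.tensor([2, 4])
loss_mask = torch.arange(4)[None, :] < lens[:, None]  # (b t)

flow_losses[loss_mask].mean()  # tensor(1.) - only the valid frames
flow_losses.mean()             # tensor(3.) - skewed by the padded frames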
@@ -3037,7 +3059,10 @@
         reward_losses = reward_losses.masked_fill(reward_loss_mask, 0.)

-        reward_loss = reduce(reward_losses, '... mtp -> mtp', 'mean') # they sum across the prediction steps (mtp dimension) - eq(9)
+        if is_var_len:
+            reward_loss = reward_losses[loss_mask_without_last].mean(dim = 0)
+        else:
+            reward_loss = reduce(reward_losses, '... mtp -> mtp', 'mean') # they sum across the prediction steps (mtp dimension) - eq(9)

         # maybe autoregressive action loss
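
The reward branch above indexes with loss_mask_without_last, the same mask minus its final timestep, which lines up with reward targets that are one frame shorter than the frame sequence. A shape-only sketch of the masked reduction (dimensions assumed, values random):

import torch

b, t_minus_1, mtp = 2, 3, 2
reward_losses = torch.rand(b, t_minus_1, mtp)
loss_mask_without_last = torch.tensor([[True, False, False],
                                       [True, True,  True]])

# boolean indexing flattens the valid (b, t-1) positions, keeping mtp
reward_loss = reward_losses[loss_mask_without_last].mean(dim = 0)  # (mtp,)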
@@ -3080,12 +3105,20 @@
         if exists(discrete_log_probs):
             discrete_log_probs = discrete_log_probs.masked_fill(~discrete_mask[..., None], 0.)

-            discrete_action_loss = reduce(-discrete_log_probs, 'mtp b t na -> mtp', 'mean')
+            if is_var_len:
+                discrete_action_losses = rearrange(-discrete_log_probs, 'mtp b t na -> b t na mtp')
+                discrete_action_loss = reduce(discrete_action_losses[loss_mask_without_last], '... mtp -> mtp', 'mean')
+            else:
+                discrete_action_loss = reduce(-discrete_log_probs, 'mtp b t na -> mtp', 'mean')

         if exists(continuous_log_probs):
             continuous_log_probs = continuous_log_probs.masked_fill(~continuous_mask[..., None], 0.)

-            continuous_action_loss = reduce(-continuous_log_probs, 'mtp b t na -> mtp', 'mean')
+            if is_var_len:
+                continuous_action_losses = rearrange(-continuous_log_probs, 'mtp b t na -> b t na mtp')
+                continuous_action_loss = reduce(continuous_action_losses[loss_mask_without_last], '... mtp -> mtp', 'mean')
+            else:
+                continuous_action_loss = reduce(-continuous_log_probs, 'mtp b t na -> mtp', 'mean')

         # handle loss normalization
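
The action-loss branches above follow the same pattern: rearrange moves mtp to the back so the (b, t-1) boolean mask can index the leading dims, then everything except mtp is averaged away. A minimal shape walk-through (names and sizes assumed):

import torch
from einops import rearrange, reduce

mtp, b, t_minus_1, na = 2, 2, 3, 4
log_probs = torch.randn(mtp, b, t_minus_1, na)
loss_mask_without_last = torch.tensor([[True, True, False],
                                       [True, False, False]])

losses = rearrange(-log_probs, 'mtp b t na -> b t na mtp')  # (b, t-1, na, mtp)
valid = losses[loss_mask_without_last]                      # (num_valid, na, mtp)
action_loss = reduce(valid, '... mtp -> mtp', 'mean')       # (mtp,)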

View File

@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.71"
+version = "0.0.72"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

View File

@@ -16,6 +16,7 @@ def exists(v):
 @param('num_residual_streams', (1, 4))
 @param('add_reward_embed_to_agent_token', (False, True))
 @param('use_time_kv_cache', (False, True))
+@param('var_len', (False, True))
 def test_e2e(
     pred_orig_latent,
     grouped_query_attn,
@@ -27,7 +28,8 @@
     condition_on_actions,
     num_residual_streams,
     add_reward_embed_to_agent_token,
-    use_time_kv_cache
+    use_time_kv_cache,
+    var_len
 ):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
@@ -95,8 +97,13 @@ def test_e2e(
     if condition_on_actions:
         actions = torch.randint(0, 4, (2, 3, 1))

+    lens = None
+
+    if var_len:
+        lens = torch.randint(1, 4, (2,))
+
     flow_loss = dynamics(
         **dynamics_input,
+        lens = lens,
         tasks = tasks,
         signal_levels = signal_levels,
         step_sizes_log2 = step_sizes_log2,