diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 9c22615..31df3d8 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -155,6 +155,15 @@ def is_power_two(num):
 def is_empty(t):
     return t.numel() == 0
 
+def lens_to_mask(t, max_len = None):
+    if not exists(max_len):
+        max_len = t.amax().item()
+
+    device = t.device
+    seq = torch.arange(max_len, device = device)
+
+    return einx.less('j, i -> i j', seq, t)
+
 def log(t, eps = 1e-20):
     return t.clamp(min = eps).log()
 
@@ -2581,6 +2590,7 @@ class DynamicsWorldModel(Module):
         *,
         video = None,           # (b v? c t vh vw)
         latents = None,         # (b t v? n d) | (b t v? d)
+        lens = None,            # (b)
         signal_levels = None,   # () | (b) | (b t)
         step_sizes = None,      # () | (b)
         step_sizes_log2 = None, # () | (b)
@@ -3014,7 +3024,19 @@
 
         flow_losses = flow_losses * loss_weight
 
-        flow_loss = flow_losses.mean()
+        # handle variable lengths if needed
+
+        is_var_len = exists(lens)
+
+        if is_var_len:
+
+            loss_mask = lens_to_mask(lens, time)
+            loss_mask_without_last = loss_mask[:, :-1]
+
+            flow_loss = flow_losses[loss_mask].mean()
+
+        else:
+            flow_loss = flow_losses.mean()
 
         # now take care of the agent token losses
 
@@ -3037,7 +3059,10 @@
 
             reward_losses = reward_losses.masked_fill(reward_loss_mask, 0.)
 
-            reward_loss = reduce(reward_losses, '... mtp -> mtp', 'mean') # they sum across the prediction steps (mtp dimension) - eq(9)
+            if is_var_len:
+                reward_loss = reward_losses[loss_mask_without_last].mean(dim = 0)
+            else:
+                reward_loss = reduce(reward_losses, '... mtp -> mtp', 'mean') # they sum across the prediction steps (mtp dimension) - eq(9)
 
         # maybe autoregressive action loss
 
@@ -3080,12 +3105,20 @@
             if exists(discrete_log_probs):
                 discrete_log_probs = discrete_log_probs.masked_fill(~discrete_mask[..., None], 0.)
 
-                discrete_action_loss = reduce(-discrete_log_probs, 'mtp b t na -> mtp', 'mean')
+                if is_var_len:
+                    discrete_action_losses = rearrange(-discrete_log_probs, 'mtp b t na -> b t na mtp')
+                    discrete_action_loss = reduce(discrete_action_losses[loss_mask_without_last], '... mtp -> mtp', 'mean')
+                else:
+                    discrete_action_loss = reduce(-discrete_log_probs, 'mtp b t na -> mtp', 'mean')
 
             if exists(continuous_log_probs):
                 continuous_log_probs = continuous_log_probs.masked_fill(~continuous_mask[..., None], 0.)
 
-                continuous_action_loss = reduce(-continuous_log_probs, 'mtp b t na -> mtp', 'mean')
+                if is_var_len:
+                    continuous_action_losses = rearrange(-continuous_log_probs, 'mtp b t na -> b t na mtp')
+                    continuous_action_loss = reduce(continuous_action_losses[loss_mask_without_last], '... mtp -> mtp', 'mean')
+                else:
+                    continuous_action_loss = reduce(-continuous_log_probs, 'mtp b t na -> mtp', 'mean')
 
         # handle loss normalization
 
diff --git a/pyproject.toml b/pyproject.toml
index d0eedcd..c126938 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.71"
+version = "0.0.72"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index d8ffc1e..87b0228 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -16,6 +16,7 @@ def exists(v):
 @param('num_residual_streams', (1, 4))
 @param('add_reward_embed_to_agent_token', (False, True))
 @param('use_time_kv_cache', (False, True))
+@param('var_len', (False, True))
 def test_e2e(
     pred_orig_latent,
     grouped_query_attn,
@@ -27,7 +28,8 @@
     condition_on_actions,
     num_residual_streams,
     add_reward_embed_to_agent_token,
-    use_time_kv_cache
+    use_time_kv_cache,
+    var_len
 ):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
 
@@ -95,8 +97,13 @@
     if condition_on_actions:
         actions = torch.randint(0, 4, (2, 3, 1))
 
+    lens = None
+    if var_len:
+        lens = torch.randint(1, 4, (2,))
+
    flow_loss = dynamics(
         **dynamics_input,
+        lens = lens,
         tasks = tasks,
         signal_levels = signal_levels,
         step_sizes_log2 = step_sizes_log2,
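
Note: below is a minimal plain-torch sketch of the masking this diff introduces, using made-up lengths and shapes for illustration only; the comparison is the einx-free equivalent of lens_to_mask above, and the masked mean mirrors flow_losses[loss_mask].mean().

import torch

# hypothetical inputs, for illustration only
lens = torch.tensor([2, 4])   # valid timesteps per batch element - shape (b,)
time = 4                      # padded time dimension of the batch

# equivalent of lens_to_mask(lens, time): mask[b, t] is True when t < lens[b]
seq = torch.arange(time)
loss_mask = seq[None, :] < lens[:, None]
# tensor([[ True,  True, False, False],
#         [ True,  True,  True,  True]])

# per-timestep flow losses, (b, t) here for simplicity
flow_losses = torch.rand(2, time)

# average only over the valid positions instead of the whole padded tensor
flow_loss = flow_losses[loss_mask].mean()

# the reward and action losses in the diff are indexed with loss_mask[:, :-1],
# since those tensors cover one fewer timestep
loss_mask_without_last = loss_mask[:, :-1]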