From 49082d8629a08e94a733509f598f1854bda44f9f Mon Sep 17 00:00:00 2001 From: lucidrains Date: Thu, 2 Oct 2025 08:36:00 -0700 Subject: [PATCH] x-space and v-space prediction in dynamics model --- dreamer4/dreamer4.py | 42 +++++++++++++++++++++++++++++++++++++++--- tests/test_dreamer.py | 12 ++++++++---- 2 files changed, 47 insertions(+), 7 deletions(-) diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index c353c54..627b33c 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -509,6 +509,7 @@ class DynamicsModel(Module): num_spatial_tokens = 32, # latents were projected into spatial tokens, and presumably pooled back for the final prediction (or one special one does the x-prediction) num_register_tokens = 8, # they claim register tokens led to better temporal consistency depth = 4, + pred_orig_latent = True, # directly predicting the original x0 data yield better results, rather than velocity (x-space vs v-space) time_block_every = 4, # every 4th block is time attn_kwargs: dict = dict( dim_head = 64, @@ -532,9 +533,14 @@ class DynamicsModel(Module): assert divisible_by(dim, 2) dim_half = dim // 2 + self.num_signal_levels = num_signal_levels + self.num_step_sizes = num_step_sizes + self.signal_levels_embed = nn.Embedding(num_signal_levels, dim_half) self.step_sizes_embed = nn.Embedding(num_step_sizes, dim_half) + self.pred_orig_latent = pred_orig_latent + # they sum all the actions into a single token self.action_learned_embed = Parameter(torch.randn(dim) * 1e-2) @@ -573,6 +579,32 @@ class DynamicsModel(Module): step_sizes = None # (b t) ): + assert not (exists(signal_levels) ^ exists(step_sizes)) + + flow_matching = exists(signal_levels) + + # flow matching if `signal_levels` passed in + + if flow_matching: + + noise = torch.randn_like(latents) + + interp = rearrange(signal_levels.float() / self.num_signal_levels, 'b t -> b t 1') + + orig_latents = latents + + latents = noise.lerp(latents, interp) + + # allow for original velocity pred + # x-space as in paper is in else clause + + if not self.pred_orig_latent: + pred_target = flow = latents - noise + else: + pred_target = latents + + # latents to spatial tokens + space_tokens = self.latents_to_spatial_tokens(latents) # pack to tokens @@ -584,14 +616,13 @@ class DynamicsModel(Module): # determine signal + step size embed for their diffusion forcing + shortcut - assert not (exists(signal_levels) ^ exists(step_sizes)) - if exists(signal_levels): signal_embed = self.signal_levels_embed(signal_levels) step_size_embed = self.step_sizes_embed(step_sizes) flow_token = cat((signal_embed, step_size_embed), dim = -1) flow_token = rearrange(flow_token, 'b t d -> b t d') + else: flow_token = registers[..., 0:0, :] @@ -619,7 +650,12 @@ class DynamicsModel(Module): pooled = reduce(space_tokens, 'b t s d -> b t d', 'mean') - return self.to_pred(pooled) + pred = self.to_pred(pooled) + + if not flow_matching: + return pred + + return F.mse_loss(pred, pred_target) # dreamer diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py index 6257b59..d64ac69 100644 --- a/tests/test_dreamer.py +++ b/tests/test_dreamer.py @@ -1,7 +1,11 @@ import pytest +param = pytest.mark.parametrize import torch -def test_e2e(): +@param('pred_orig_latent', (False, True)) +def test_e2e( + pred_orig_latent +): from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel tokenizer = VideoTokenizer(512, dim_latent = 32, patch_size = 32) @@ -13,10 +17,10 @@ def test_e2e(): latents = tokenizer(x, return_latents = True) assert latents.shape[-1] == 32 - dynamics = DynamicsModel(512, dim_latent = 32, num_signal_levels = 500, num_step_sizes = 32) + dynamics = DynamicsModel(512, dim_latent = 32, num_signal_levels = 500, num_step_sizes = 32, pred_orig_latent = pred_orig_latent) signal_levels = torch.randint(0, 500, (2, 4)) step_sizes = torch.randint(0, 32, (2, 4)) - pred = dynamics(latents, signal_levels = signal_levels, step_sizes = step_sizes) - assert pred.shape == latents.shape + flow_loss = dynamics(latents, signal_levels = signal_levels, step_sizes = step_sizes) + assert flow_loss.numel() == 1