lucidrains 2025-10-02 11:49:22 -07:00
parent 51e0852604
commit e23a5294ec
2 changed files with 22 additions and 18 deletions

View File

@@ -531,7 +531,7 @@ class DynamicsModel(Module):
        num_spatial_tokens = 32,         # latents were projected into spatial tokens, and presumably pooled back for the final prediction (or one special one does the x-prediction)
        num_register_tokens = 8,         # they claim register tokens led to better temporal consistency
        depth = 4,
-       pred_orig_latent = True,         # directly predicting the original x0 data yield better results, rather than velocity (x-space vs v-space)
+       pred_is_clean_latents = True,    # directly predicting the original x0 data yield better results, rather than velocity (x-space vs v-space)
        time_block_every = 4,            # every 4th block is time
        attn_kwargs: dict = dict(
            dim_head = 64,
@@ -561,7 +561,7 @@ class DynamicsModel(Module):
        self.signal_levels_embed = nn.Embedding(num_signal_levels, dim_half)
        self.step_sizes_embed = nn.Embedding(num_step_sizes, dim_half)

-       self.pred_orig_latent = pred_orig_latent
+       self.pred_is_clean_latents = pred_is_clean_latents

        # they sum all the actions into a single token
@@ -611,19 +611,15 @@ class DynamicsModel(Module):
        noise = torch.randn_like(latents)

-       interp = rearrange(signal_levels.float() / self.num_signal_levels, 'b t -> b t 1')
-
-       orig_latents = latents
-       latents = noise.lerp(latents, interp)
-
-       # allow for original velocity pred
-       # x-space as in paper is in else clause
-
-       if not self.pred_orig_latent:
-           pred_target = flow = latents - noise
-       else:
-           pred_target = latents
+       times = signal_levels.float() / self.num_signal_levels
+       times = rearrange(times, 'b t -> b t 1')
+
+       flow = latents - noise
+
+       latents = noise.lerp(latents, times)
+
+       noised_latents = latents

        # latents to spatial tokens
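
For readers following the hunk above: noise.lerp(latents, times) is the straight-line interpolation between noise and the clean latents, so the flow (velocity) target is simply latents - noise. A minimal standalone sketch of that identity, using hypothetical names and shapes rather than repo code:

    import torch

    x0 = torch.randn(2, 4, 32)        # clean latents (hypothetical shape)
    noise = torch.randn_like(x0)
    t = torch.rand(2, 4, 1)           # normalized signal level, broadcast over the latent dim

    x_t = noise.lerp(x0, t)           # == (1 - t) * noise + t * x0
    assert torch.allclose(x_t, (1 - t) * noise + t * x0, atol = 1e-6)

    flow = x0 - noise                 # time derivative of x_t, the v-space target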
@@ -677,7 +673,15 @@ class DynamicsModel(Module):
        if not flow_matching:
            return pred

-       return F.mse_loss(pred, pred_target)
+       # x-space vs v-space
+
+       if self.pred_is_clean_latents:
+           denoised_latent = pred
+           pred_flow = (denoised_latent - noised_latents) / (1. - times)
+       else:
+           pred_flow = pred
+
+       return F.mse_loss(pred_flow, flow)

# dreamer
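
The new loss converts an x-space prediction back to v-space before the MSE: since x_t = (1 - t) * noise + t * x0, it follows that (x0 - x_t) / (1 - t) = x0 - noise, which is the pred_flow computed above when the model predicts clean latents. A quick standalone check of that algebra (hypothetical names, not repo code):

    import torch

    x0 = torch.randn(2, 4, 32)             # clean latents (hypothetical shape)
    noise = torch.randn_like(x0)
    t = torch.rand(2, 4, 1) * 0.99         # keep 1 - t away from zero

    x_t = noise.lerp(x0, t)                # noised latents
    flow = x0 - noise                      # v-space target

    x0_hat = x0                            # pretend the x-space prediction is exact
    pred_flow = (x0_hat - x_t) / (1. - t)  # implied velocity from the x-space prediction

    assert torch.allclose(pred_flow, flow, atol = 1e-4)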

View File

@@ -2,9 +2,9 @@ import pytest
param = pytest.mark.parametrize

import torch

-@param('pred_orig_latent', (False, True))
+@param('pred_is_clean_latents', (False, True))
def test_e2e(
-   pred_orig_latent
+   pred_is_clean_latents
):
    from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel
@@ -17,7 +17,7 @@ def test_e2e(
    latents = tokenizer(x, return_latents = True)
    assert latents.shape[-1] == 32

-   dynamics = DynamicsModel(512, dim_latent = 32, num_signal_levels = 500, num_step_sizes = 32, pred_orig_latent = pred_orig_latent)
+   dynamics = DynamicsModel(512, dim_latent = 32, num_signal_levels = 500, num_step_sizes = 32, pred_is_clean_latents = pred_is_clean_latents)

    signal_levels = torch.randint(0, 500, (2, 4))
    step_sizes = torch.randint(0, 32, (2, 4))