able to accept raw video for dynamics model, if tokenizer passed in

This commit is contained in:
lucidrains 2025-10-04 06:57:54 -07:00
parent 8373cb13ec
commit 895a867a66
2 changed files with 43 additions and 8 deletions

View File

@ -698,6 +698,7 @@ class DynamicsModel(Module):
self,
dim,
dim_latent,
video_tokenizer: VideoTokenizer | None = None,
num_signal_levels = 500,
num_step_sizes = 32,
num_spatial_tokens = 32, # latents were projected into spatial tokens, and presumably pooled back for the final prediction (or one special one does the x-prediction)
@ -714,6 +715,10 @@ class DynamicsModel(Module):
):
super().__init__()
# can accept raw video if tokenizer is passed in
self.video_tokenizer = video_tokenizer
# spatial and register tokens
self.latents_to_spatial_tokens = Sequential(
@ -769,12 +774,34 @@ class DynamicsModel(Module):
Linear(dim, dim_latent)
)
def parameter(self):
params = super().parameters()
if not exists(self.video_tokenizer):
return params
return list(set(params) - set(self.video_tokenizer.parameters()))
def forward(
self,
latents, # (b t d)
*,
video = None,
latents = None, # (b t d)
signal_levels = None, # (b t)
step_sizes = None # (b t)
):
# handle video or latents
assert exists(video) ^ exists(latents)
if exists(video):
assert exists(self.video_tokenizer), 'video_tokenizer must be passed in if training from raw video on dynamics model'
with torch.no_grad():
self.video_tokenizer.eval()
latents = self.video_tokenizer(video, return_latents = True)
# flow related
assert not (exists(signal_levels) ^ exists(step_sizes))

View File

@ -3,26 +3,29 @@ param = pytest.mark.parametrize
import torch
@param('pred_orig_latent', (False, True))
@param('gqa', (False, True))
@param('grouped_query_attn', (False, True))
@param('dynamics_with_video_input', (False, True))
def test_e2e(
pred_orig_latent,
gqa
grouped_query_attn,
dynamics_with_video_input
):
from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel
tokenizer = VideoTokenizer(512, dim_latent = 32, patch_size = 32)
x = torch.randn(2, 3, 4, 256, 256)
video = torch.randn(2, 3, 4, 256, 256)
loss = tokenizer(x)
loss = tokenizer(video)
assert loss.numel() == 1
latents = tokenizer(x, return_latents = True)
latents = tokenizer(video, return_latents = True)
assert latents.shape[-1] == 32
query_heads, heads = (16, 4) if gqa else (8, 8)
query_heads, heads = (16, 4) if grouped_query_attn else (8, 8)
dynamics = DynamicsModel(
512,
video_tokenizer = tokenizer,
dim_latent = 32,
num_signal_levels = 500,
num_step_sizes = 32,
@ -36,7 +39,12 @@ def test_e2e(
signal_levels = torch.randint(0, 500, (2, 4))
step_sizes = torch.randint(0, 32, (2, 4))
flow_loss = dynamics(latents, signal_levels = signal_levels, step_sizes = step_sizes)
if dynamics_with_video_input:
dynamics_input = dict(video = video)
else:
dynamics_input = dict(latents = latents)
flow_loss = dynamics(**dynamics_input, signal_levels = signal_levels, step_sizes = step_sizes)
assert flow_loss.numel() == 1
def test_symexp_two_hot():