handle fewer spatial tokens than latent tokens in the dynamics model

lucidrains 2025-10-06 09:19:27 -07:00
parent bfbecb4968
commit 25b8de91cc
2 changed files with 62 additions and 21 deletions

dreamer4/dreamer4.py

@@ -37,7 +37,7 @@ from accelerate import Accelerator

 import einx
 from einops import einsum, rearrange, repeat, reduce, pack, unpack
-from einops.layers.torch import Rearrange
+from einops.layers.torch import Rearrange, Reduce

 # flex attention - but will make sure it works if it is not available
 # may also end up crafting own custom flash attention kernel for this work
@@ -993,7 +993,8 @@ class DynamicsModel(Module):
         video_tokenizer: VideoTokenizer | None = None,
         max_steps = 64, # K_max in paper
         num_register_tokens = 8, # they claim register tokens led to better temporal consistency
-        num_spatial_tokens_per_latent = 2, # latents can be projected to greater number of tokens
+        num_spatial_tokens = 2, # latents projected to greater number of spatial tokens
+        num_latent_tokens = None,
         num_tasks = 0,
         depth = 4,
         pred_orig_latent = True, # directly predicting the original x0 data yield better results, rather than velocity (x-space vs v-space)
@@ -1014,13 +1015,50 @@ class DynamicsModel(Module):

         self.video_tokenizer = video_tokenizer

-        # spatial and register tokens
-
-        self.latents_to_spatial_tokens = Sequential(
-            Linear(dim_latent, dim * num_spatial_tokens_per_latent),
-            Rearrange('... (tokens d) -> ... tokens d', tokens = num_spatial_tokens_per_latent)
-        )
+        if exists(video_tokenizer):
+            num_latent_tokens = default(num_latent_tokens, video_tokenizer.num_latent_tokens)
+            assert video_tokenizer.num_latent_tokens == num_latent_tokens, f'`num_latent_tokens` must be the same for the tokenizer and dynamics model'
+
+        assert exists(num_latent_tokens), '`num_latent_tokens` must be set'
+
+        # spatial
+
+        self.num_latent_tokens = num_latent_tokens
+
+        if num_spatial_tokens >= num_latent_tokens:
+            assert divisible_by(num_spatial_tokens, num_latent_tokens)
+            expand_factor = num_spatial_tokens // num_latent_tokens
+
+            self.latents_to_spatial_tokens = Sequential(
+                Linear(dim_latent, dim * expand_factor),
+                Rearrange('... (s d) -> ... s d', s = expand_factor)
+            )
+
+            self.to_latent_pred = Sequential(
+                Reduce('b t n s d -> b t n d', 'mean'),
+                RMSNorm(dim),
+                LinearNoBias(dim, dim_latent)
+            )
+        else:
+            assert divisible_by(num_latent_tokens, num_spatial_tokens)
+            latent_tokens_to_space = num_latent_tokens // num_spatial_tokens
+
+            self.latents_to_spatial_tokens = Sequential(
+                Rearrange('b t n d -> b t (n d)'),
+                Linear(num_latent_tokens * dim_latent, dim * num_spatial_tokens),
+                Rearrange('b t (s d) -> b t s d', s = num_spatial_tokens)
+            )
+
+            self.to_latent_pred = Sequential(
+                RMSNorm(dim),
+                LinearNoBias(dim, dim_latent * latent_tokens_to_space),
+                Rearrange('b t s (n d) -> b t (s n) d', n = latent_tokens_to_space)
+            )
+
+        # register tokens

         self.num_register_tokens = num_register_tokens
         self.register_tokens = Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
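
The branching above follows a simple shape argument. Below is a minimal sketch of the two projection regimes in plain torch and einops, with toy dimensions (b, t, n, dim_latent, dim here are made-up values, not the repo's defaults); it is not the exact module, just the same rearrangement logic:

# minimal sketch of the two projection regimes, with hypothetical toy dims
import torch
from torch import nn
from einops import rearrange

b, t, n, dim_latent, dim = 2, 3, 4, 8, 16      # n = num_latent_tokens
latents = torch.randn(b, t, n, dim_latent)

# expansion: num_spatial_tokens (8) >= num_latent_tokens (4)
# each latent is projected to expand_factor spatial tokens of its own
expand_factor = 8 // n
expand = nn.Linear(dim_latent, dim * expand_factor)
spatial = rearrange(expand(latents), '... (s d) -> ... s d', s = expand_factor)
assert spatial.shape == (b, t, n, expand_factor, dim)   # b t n s d

# contraction: num_spatial_tokens (2) < num_latent_tokens (4)
# all latents are flattened together, then projected to s spatial tokens
s = 2
contract = nn.Linear(n * dim_latent, dim * s)
flat = rearrange(latents, 'b t n d -> b t (n d)')
spatial = rearrange(contract(flat), 'b t (s d) -> b t s d', s = s)
assert spatial.shape == (b, t, s, dim)                  # b t s d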
@@ -1084,13 +1122,6 @@ class DynamicsModel(Module):

         self.layers = ModuleList(layers)

         self.is_time = is_time

-        # to prediction
-
-        self.to_pred = Sequential(
-            RMSNorm(dim),
-            Linear(dim, dim_latent)
-        )
-
     def parameter(self):
         params = super().parameters()
@@ -1121,6 +1152,8 @@ class DynamicsModel(Module):
         if latents.ndim == 3:
             latents = rearrange(latents, 'b t d -> b t 1 d') # 1 latent edge case

+        assert latents.shape[-2] == self.num_latent_tokens
+
         # variables

         batch, time, device = *latents.shape[:2], latents.device
@@ -1262,9 +1295,7 @@ class DynamicsModel(Module):

         space_tokens = inverse_pack_space_per_latent(space_tokens)

-        pooled = reduce(space_tokens, 'b t nl s d -> b t nl d', 'mean')
-
-        pred = self.to_pred(pooled)
+        pred = self.to_latent_pred(space_tokens)

         return pred
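
Mapping back to latents is the inverse of each regime, which is why the old shared head (mean pool plus linear) is removed and each branch carries its own `to_latent_pred`. A hedged sketch of the two heads, reusing the toy dimensions from the earlier sketch and omitting the RMSNorm for brevity:

# sketch of the two prediction heads, with the same hypothetical toy dims;
# RMSNorm is left out here for brevity
import torch
from torch import nn
from einops import rearrange, reduce

b, t, n, s, dim, dim_latent = 2, 3, 4, 2, 16, 8

# expansion head: mean-pool the s tokens belonging to each latent, project down
space_tokens = torch.randn(b, t, n, 2, dim)             # b t n s d
pooled = reduce(space_tokens, 'b t n s d -> b t n d', 'mean')
pred = nn.Linear(dim, dim_latent, bias = False)(pooled)
assert pred.shape == (b, t, n, dim_latent)

# contraction head: each spatial token predicts n // s latents at once
per_space = n // s                                      # latent_tokens_to_space
space_tokens = torch.randn(b, t, s, dim)                # b t s d
pred = nn.Linear(dim, dim_latent * per_space, bias = False)(space_tokens)
pred = rearrange(pred, 'b t s (n d) -> b t (s n) d', n = per_space)
assert pred.shape == (b, t, n, dim_latent)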

tests/test_dreamer4.py

@@ -7,16 +7,24 @@ import torch

 @param('dynamics_with_video_input', (False, True))
 @param('prob_no_shortcut_train', (None, 0., 1.))
 @param('add_task_embeds', (False, True))
+@param('num_spatial_tokens', (2, 8))
 def test_e2e(
     pred_orig_latent,
     grouped_query_attn,
     dynamics_with_video_input,
     prob_no_shortcut_train,
-    add_task_embeds
+    add_task_embeds,
+    num_spatial_tokens
 ):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel

-    tokenizer = VideoTokenizer(512, dim_latent = 32, patch_size = 32)
+    tokenizer = VideoTokenizer(
+        512,
+        dim_latent = 32,
+        patch_size = 32,
+        num_latent_tokens = 4
+    )

     video = torch.randn(2, 3, 4, 256, 256)
     loss = tokenizer(video)
@@ -33,6 +41,8 @@ def test_e2e(
         dim_latent = 32,
         max_steps = 64,
         num_tasks = 4,
+        num_latent_tokens = 4,
+        num_spatial_tokens = num_spatial_tokens,
         pred_orig_latent = pred_orig_latent,
         attn_kwargs = dict(
             heads = heads,
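
Parametrizing `num_spatial_tokens` over (2, 8) against `num_latent_tokens = 4` exercises both branches: 2 hits the new contraction path, 8 hits the existing expansion path, and both values satisfy the branch-specific divisibility asserts. A quick plain-Python check mirroring those asserts:

# both parametrized values pass the divisibility constraint of their branch
num_latent_tokens = 4

for num_spatial_tokens in (2, 8):
    if num_spatial_tokens >= num_latent_tokens:
        assert num_spatial_tokens % num_latent_tokens == 0   # expansion branch
    else:
        assert num_latent_tokens % num_spatial_tokens == 0   # contraction branch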