From bfbecb4968659e56bbdd74912bc975515b321be6 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Mon, 6 Oct 2025 08:16:55 -0700 Subject: [PATCH] an anonymous researcher pointed out that the video tokenizer may be using multiple latents per frame --- dreamer4/dreamer4.py | 93 +++++++++++++++++++++++++++++--------------- 1 file changed, 61 insertions(+), 32 deletions(-) diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 4c0fe41..de03645 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -98,6 +98,15 @@ def pad_at_dim( zeros = ((0, 0) * dims_from_right) return F.pad(t, (*zeros, *pad), value = value) +def align_dims_left(t, aligned_to): + shape = t.shape + num_right_dims = aligned_to.ndim - t.ndim + + if num_right_dims < 0: + return + + return t.reshape(*shape, *((1,) * num_right_dims)) + def l2norm(t): return F.normalize(t, dim = -1, p = 2) @@ -677,6 +686,7 @@ class VideoTokenizer(Module): dim, dim_latent, patch_size, + num_latent_tokens = 4, encoder_depth = 4, decoder_depth = 4, attn_kwargs: dict = dict(), @@ -701,7 +711,9 @@ class VideoTokenizer(Module): # special tokens - self.latent_token = Parameter(torch.randn(dim) * 1e-2) + assert num_latent_tokens >= 1 + self.num_latent_tokens = num_latent_tokens + self.latent_tokens = Parameter(torch.randn(num_latent_tokens, dim) * 1e-2) # mae masking - Kaiming He paper from long ago @@ -829,7 +841,7 @@ class VideoTokenizer(Module): # give the latents an out of bounds position and assume the network will figure it out - positions = pad_at_dim(positions, (0, 1), dim = -2, value = -1) # todo - make this value configurable, and ultimately craft own flash attention function where certain positions can be unrotated + positions = pad_at_dim(positions, (0, self.num_latent_tokens), dim = -2, value = -1) # todo - make this value configurable, and ultimately craft own flash attention function where certain positions can be unrotated positions = rearrange(positions, 't hw p -> (t hw) p') @@ -855,7 +867,7 @@ class VideoTokenizer(Module): # add the latent - latents = repeat(self.latent_token, 'd -> b t d', b = tokens.shape[0], t = tokens.shape[1]) + latents = repeat(self.latent_tokens, 'n d -> b t n d', b = tokens.shape[0], t = tokens.shape[1]) tokens, packed_latent_shape = pack((tokens, latents), 'b t * d') @@ -917,7 +929,11 @@ class VideoTokenizer(Module): decoder_pos_emb = self.to_decoder_pos_emb(space_height_width_coor) decoder_pos_emb = repeat(decoder_pos_emb, '... -> b t ...', b = batch, t = time) - tokens, _ = pack((decoder_pos_emb, latent_tokens), 'b * d') + tokens, packed_latent_shape = pack((decoder_pos_emb, latent_tokens), 'b t * d') + + # pack time + + tokens, inverse_pack_time = pack_one(tokens, 'b * d') # decoder attend @@ -936,13 +952,9 @@ class VideoTokenizer(Module): tokens = inverse_pack_time(tokens) - # excise latents + # unpack latents - tokens = tokens[..., :-1, :] - - # unpack space - - tokens = inverse_pack_space(tokens) + tokens, latent_tokens = unpack(tokens, packed_latent_shape, 'b t * d') # project back to patches @@ -979,9 +991,9 @@ class DynamicsModel(Module): dim, dim_latent, video_tokenizer: VideoTokenizer | None = None, - max_steps = 64, # K_max in paper - num_spatial_tokens = 32, # latents were projected into spatial tokens, and presumably pooled back for the final prediction (or one special one does the x-prediction) - num_register_tokens = 8, # they claim register tokens led to better temporal consistency + max_steps = 64, # K_max in paper + num_register_tokens = 8, # they claim register tokens led to better temporal consistency + num_spatial_tokens_per_latent = 2, # latents can be projected to greater number of tokens num_tasks = 0, depth = 4, pred_orig_latent = True, # directly predicting the original x0 data yield better results, rather than velocity (x-space vs v-space) @@ -1005,10 +1017,11 @@ class DynamicsModel(Module): # spatial and register tokens self.latents_to_spatial_tokens = Sequential( - Linear(dim_latent, dim * num_spatial_tokens), - Rearrange('... (tokens d) -> ... tokens d', tokens = num_spatial_tokens) + Linear(dim_latent, dim * num_spatial_tokens_per_latent), + Rearrange('... (tokens d) -> ... tokens d', tokens = num_spatial_tokens_per_latent) ) + self.num_register_tokens = num_register_tokens self.register_tokens = Parameter(torch.randn(num_register_tokens, dim) * 1e-2) # signal and step sizes @@ -1039,15 +1052,6 @@ class DynamicsModel(Module): self.num_tasks = num_tasks self.task_embed = nn.Embedding(num_tasks, dim) - # calculate "space" seq len - - self.space_seq_len = ( - 1 # action / agent token - + 1 # signal + step - + num_register_tokens - + num_spatial_tokens - ) - # attention self.attn_softclamp_value = attn_softclamp_value @@ -1099,7 +1103,7 @@ class DynamicsModel(Module): self, *, video = None, - latents = None, # (b t d) + latents = None, # (b t n d) | (b t d) signal_levels = None, # (b t) step_sizes_log2 = None, # (b) tasks = None, # (b) @@ -1114,6 +1118,11 @@ class DynamicsModel(Module): latents = self.video_tokenizer.tokenize(video) + if latents.ndim == 3: + latents = rearrange(latents, 'b t d -> b t 1 d') # 1 latent edge case + + # variables + batch, time, device = *latents.shape[:2], latents.device # flow related @@ -1146,7 +1155,11 @@ class DynamicsModel(Module): # times is from 0 to 1 - times = rearrange(signal_levels.float() / self.max_steps, 'b t -> b t 1') + def get_times_from_signal_level(signal_levels): + times = signal_levels.float() / self.max_steps + return align_dims_left(times, latents) + + times = get_times_from_signal_level(signal_levels) # noise from 0 as noise to 1 as data @@ -1169,6 +1182,10 @@ class DynamicsModel(Module): space_tokens = self.latents_to_spatial_tokens(noised_latents) + space_tokens, inverse_pack_space_per_latent = pack_one(space_tokens, 'b t * d') + + num_spatial_tokens = space_tokens.shape[-2] + # pack to tokens # [signal + step size embed] [latent space tokens] [register] [actions / agent] @@ -1200,7 +1217,14 @@ class DynamicsModel(Module): attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, device = device) - space_attend = get_attend_fn(causal = False, seq_len = self.space_seq_len, k_seq_len = self.space_seq_len, num_special_tokens = 1, **attend_kwargs) # space has an agent token on the right-hand side for reinforcement learning - cannot be attended to by modality + space_seq_len = ( + 1 # action / agent token + + 1 # signal + step + + self.num_register_tokens + + num_spatial_tokens + ) + + space_attend = get_attend_fn(causal = False, seq_len = space_seq_len, k_seq_len = space_seq_len, num_special_tokens = 1, **attend_kwargs) # space has an agent token on the right-hand side for reinforcement learning - cannot be attended to by modality time_attend = get_attend_fn(causal = True, seq_len = time, k_seq_len = time, **attend_kwargs) @@ -1236,7 +1260,9 @@ class DynamicsModel(Module): # pooling - pooled = reduce(space_tokens, 'b t s d -> b t d', 'mean') + space_tokens = inverse_pack_space_per_latent(space_tokens) + + pooled = reduce(space_tokens, 'b t nl s d -> b t nl d', 'mean') pred = self.to_pred(pooled) @@ -1285,21 +1311,24 @@ class DynamicsModel(Module): if is_v_space_pred: first_step_pred_flow = first_step_pred else: - first_times = signal_levels[..., None].float() / self.max_steps + first_times = get_times_from_signal_level(signal_levels) first_step_pred_flow = (first_step_pred - noised_latents) / (1. - first_times) # take a half step - denoised_latent = noised_latents + first_step_pred_flow * (half_step_size[:, None, None] / self.max_steps) + half_step_size_align_left = align_dims_left(half_step_size, noised_latents) + + denoised_latent = noised_latents + first_step_pred_flow * (half_step_size_align_left / self.max_steps) # get second prediction for b'' - second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels + half_step_size[:, None], step_sizes_log2_minus_one, agent_tokens) + signal_levels_plus_half_step = signal_levels + half_step_size[:, None] + second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels_plus_half_step, step_sizes_log2_minus_one, agent_tokens) if is_v_space_pred: second_step_pred_flow = second_step_pred else: - second_times = signal_levels[..., None].float() / self.max_steps + second_times = get_times_from_signal_level(signal_levels_plus_half_step) second_step_pred_flow = (second_step_pred - denoised_latent) / (1. - second_times) # pred target is sg(b' + b'') / 2