From 1b7f6e787d69cc5ab1b374fd531469ff7910bd97 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Sat, 4 Oct 2025 09:22:06 -0700
Subject: [PATCH] rotate in the 3d rotary embeddings for the video tokenizer
 for both encoder / decoder

---
 dreamer4/dreamer4.py | 140 ++++++++++++++++++++++++++++++------------
 1 file changed, 100 insertions(+), 40 deletions(-)

diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 50f5e90..e2e6e37 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -27,6 +27,7 @@ from accelerate import Accelerator
 # f - frequencies (rotary)
 # p - positions (3 for spacetime in this work)
 # t - time
+# g - groups of query heads to key heads (gqa)
 # vc - video channels
 # vh, vw - video height and width

@@ -66,6 +67,8 @@ def first(arr):
 def divisible_by(num, den):
     return (num % den) == 0

+# tensor helpers
+
 def pack_one(t, pattern):

     packed, packed_shape = pack([t], pattern)
@@ -75,6 +78,19 @@

     return packed, inverse

+def pad_at_dim(
+    t,
+    pad: tuple[int, int],
+    dim = -1,
+    value = 0.
+):
+    if pad == (0, 0):
+        return t
+
+    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
+    zeros = ((0, 0) * dims_from_right)
+    return F.pad(t, (*zeros, *pad), value = value)
+
 def l2norm(t):
     return F.normalize(t, dim = -1, p = 2)

@@ -258,21 +274,30 @@ class GoldenGateRoPENd(Module):
-        pos # (b n p)
+        pos # (n p)
     ):

-        freqs = rearrange(self.freqs, 'h f p -> 1 h 1 f p')
-        positions = rearrange(pos.float(), 'b n p -> b 1 n 1 p')
+        freqs = rearrange(self.freqs, 'h f p -> h 1 f p')
+        positions = rearrange(pos.float(), 'n p -> 1 n 1 p')

-        # thetas for freqs and positions (batch, head, seq, freq)
+        # thetas for freqs and positions (head, seq, freq)

-        theta = reduce(freqs * positions, 'b h n f p -> b h n f', 'sum')
+        theta = reduce(freqs * positions, 'h n f p -> h n f', 'sum')

         return theta

 def apply_rotations(
-    theta # (b h n f)
+    theta, # (h n f)
+    qk
 ):
-    dtype = x
+    rotary_heads = theta.shape[0]
+    heads, dtype = qk.shape[1], qk.dtype

-    x, y = rearrange(x.float(), '... (split d) -> split ... d', split = 2) # (b, h, n, f)
+    # handle gqa for rotary
+
+    if rotary_heads < heads:
+        assert divisible_by(heads, rotary_heads)
+        groups = heads // rotary_heads
+        theta = repeat(theta, 'h ... -> (h g) ...', g = groups)
+
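+    # split the head dim into two halves x and y, one value from each forming the pair rotated by theta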
+    x, y = rearrange(qk.float(), '... (split d) -> split ... d', split = 2) # (b, h, n, f)

     # apply rotations

@@ -360,7 +385,7 @@ def softclamp_score_mod(value):

 # todo - reuse the inner function from flex attn above with broadcasting

-def nonflex_block_causal_mask(seq_len, block_size, device = None):
+def block_causal_mask(seq_len, block_size, device = None):
     blocks = ceil(seq_len / block_size)

     causal_mask = torch.ones((blocks, blocks), device = device, dtype = torch.bool).tril()
@@ -443,7 +468,6 @@ class Attention(Module):

         # scaling, splitting and merging of heads

-        self.scale = dim_head ** -0.5
         self.split_heads = Rearrange('b n (h d) -> b h n d', d = dim_head)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')

@@ -464,13 +488,8 @@
         tokens, # (b n d)
         kv_cache = None,
         return_kv_cache = False,
-        attend_fn: Callable | None = None,
-        attend_kwargs: dict = dict(
-            softclamp_value = None,
-            causal = False,
-            mask = None,
-            scale = None
-        )
+        rotary_pos_emb = None,
+        attend_fn: Callable | None = None
     ):

         tokens, inverse_packed_batch = pack_one(tokens, '* n d')
@@ -494,11 +513,17 @@
             k = cat((ck, k), dim = -2)
             v = cat((cv, v), dim = -2)

+        # rotary
+
+        if exists(rotary_pos_emb):
+            q = apply_rotations(rotary_pos_emb, q)
+            k = apply_rotations(rotary_pos_emb, k)
+
         # attention

         attend_fn = default(attend_fn, naive_attend)

-        out = attend_fn(q, k, v, **attend_kwargs)
+        out = attend_fn(q, k, v)

         # merge heads

@@ -550,17 +575,21 @@
         patch_size,
         encoder_depth = 4,
         decoder_depth = 4,
-        attn_kwargs: dict = dict(
-            dim_head = 64,
-            heads = 8,
-        ),
+        attn_kwargs: dict = dict(),
+        attn_dim_head = 64,
+        attn_heads = 8,
         attn_softclamp_value = 50.,
         ff_kwargs: dict = dict(),
         decoder_pos_mlp_depth = 2,
         channels = 3,
         per_image_patch_mask_prob = (0., 0.9), # probability of patch masking appears to be a per-image probability drawn uniformly between 0. and 0.9 - if you are a phd student and think i'm mistaken, please open an issue
         lpips_loss_network: Module | None = None,
-        lpips_loss_weight = 0.2
+        lpips_loss_weight = 0.2,
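+        # min / max frequency range and proportion of zero (unrotated) frequencies for the 3d rotary below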
+        nd_rotary_kwargs: dict = dict(
+            rope_min_freq = 1.,
+            rope_max_freq = 10000.,
+            rope_p_zero_freqs = 0.
+        )
     ):
         super().__init__()
@@ -589,6 +618,15 @@
             Rearrange('b t h w (p1 p2 c) -> b c t (h p1) (w p2)', p1 = patch_size, p2 = patch_size),
         )

+        # 3d rotations
+
+        self.spacetime_rotary = GoldenGateRoPENd(
+            dim_pos = 3,
+            heads = attn_heads,
+            dim_head = attn_dim_head,
+            **nd_rotary_kwargs
+        )
+
         # attention related

         self.attn_softclamp_value = attn_softclamp_value
@@ -599,7 +637,7 @@

         for _ in range(encoder_depth):
             encoder_layers.append(ModuleList([
-                Attention(dim = dim, **attn_kwargs),
+                Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, **attn_kwargs),
                 SwiGLUFeedforward(dim = dim, **ff_kwargs)
             ]))

@@ -630,7 +668,7 @@

         for _ in range(decoder_depth):
             decoder_layers.append(ModuleList([
-                Attention(dim = dim, **attn_kwargs),
+                Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, **attn_kwargs),
                 SwiGLUFeedforward(dim = dim, **ff_kwargs)
             ]))

@@ -662,11 +700,9 @@
         mask_patches = None,
         return_all_losses = False
     ):
-        batch, time = video.shape[0], video.shape[2]
+        batch, _, time, height, width = video.shape
         patch_size, device = self.patch_size, video.device

-        *_, height, width = video.shape
-
         assert divisible_by(height, patch_size) and divisible_by(width, patch_size)

         # to tokens
@@ -677,6 +713,24 @@

         num_patch_height, num_patch_width, _ = tokens.shape[-3:]

+        # rotary positions
+
+        positions = stack(torch.meshgrid(
+            arange(time, device = device),
+            arange(num_patch_height, device = device),
+            arange(num_patch_width, device = device),
+            indexing = 'ij'
+        ), dim = -1)
+
+        positions = rearrange(positions, 't h w p -> t (h w) p')
+
+        # give the latents an out of bounds position and assume the network will figure it out
+
+        positions = pad_at_dim(positions, (0, 1), dim = -2, value = -1) # todo - make this value configurable, and ultimately craft own flash attention function where certain positions can be unrotated
+
+        positions = rearrange(positions, 't hw p -> (t hw) p')
+
+        rotary_pos_emb = self.spacetime_rotary(positions)
+
         # masking

         mask_patches = default(mask_patches, self.training)
@@ -697,9 +751,9 @@

         # add the latent

-        latents = repeat(self.latent_token, 'd -> b t 1 d', b = tokens.shape[0], t = tokens.shape[1])
+        latents = repeat(self.latent_token, 'd -> b t d', b = tokens.shape[0], t = tokens.shape[1])

-        tokens = cat((tokens, latents), dim = -2)
+        tokens, packed_latent_shape = pack((tokens, latents), 'b t * d')

         # pack time

@@ -709,10 +763,12 @@

         attend_kwargs = dict(softclamp_value = self.attn_softclamp_value)

+        attend_fn = partial(naive_attend, **attend_kwargs)
+
         # encoder

         for attn, ff in self.encoder_layers:
-            tokens = attn(tokens, attend_kwargs = attend_kwargs) + tokens
+            tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = attend_fn) + tokens
             tokens = ff(tokens) + tokens

         tokens = self.encoder_norm(tokens)
@@ -720,9 +776,10 @@
         # latent bottleneck

         tokens = inverse_pack_time(tokens)
-        tokens = tokens[..., -1, :]

-        latents = self.encoded_to_latents(tokens)
+        tokens, latents = unpack(tokens, packed_latent_shape, 'b t * d')
+
+        latents = self.encoded_to_latents(latents)

         if return_latents:
             return latents
@@ -744,7 +801,8 @@

         # decoder attention

         for attn, ff in self.decoder_layers:
-            tokens = attn(tokens, attend_kwargs = attend_kwargs) + tokens
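+            # decoder re-applies the same spacetime rotations, the latent token keeping its out of bounds position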
+            tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = attend_fn) + tokens
+
             tokens = ff(tokens) + tokens

         tokens = self.decoder_norm(tokens)
@@ -959,6 +1017,14 @@ class DynamicsModel(Module):

         tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, agent_token], 'b t * d')

+        # attend functions for space and time
+
+        attend_kwargs = dict(softclamp_value = self.attn_softclamp_value)
+
+        space_attend = partial(naive_attend, causal = False, **attend_kwargs)
+
+        time_attend = partial(naive_attend, causal = True, **attend_kwargs)
+
         # attention

         for (pre_attn_rearrange, post_attn_rearrange, attn, ff), layer_is_time in zip(self.layers, self.is_time):
@@ -967,17 +1033,11 @@
             tokens = pre_attn_rearrange(tokens)

             # when it is an axial time attention block, attention should be causal

-            attend_kwargs = dict()
-
-            if layer_is_time:
-                attend_kwargs.update(
-                    softclamp_value = self.attn_softclamp_value,
-                    causal = True
-                )
+            attend_fn = time_attend if layer_is_time else space_attend

             # attention layer

-            tokens = attn(tokens, attend_kwargs = attend_kwargs) + tokens
+            tokens = attn(tokens, attend_fn = attend_fn) + tokens

             tokens = post_attn_rearrange(tokens)
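
For a quick sanity check of the new rotary path in isolation, a minimal sketch
follows (assuming the GoldenGateRoPENd constructor arguments and the
apply_rotations signature exactly as introduced in this patch; the shapes are
hypothetical and the out of bounds latent slot is skipped):

    import torch
    from dreamer4.dreamer4 import GoldenGateRoPENd, apply_rotations

    heads, dim_head = 8, 64

    rotary = GoldenGateRoPENd(
        dim_pos = 3,
        heads = heads,
        dim_head = dim_head,
        rope_min_freq = 1.,
        rope_max_freq = 10000.,
        rope_p_zero_freqs = 0.
    )

    # (time, height, width) positions for 2 frames of a 4 x 4 patch grid, flattened to (n, 3)

    t, h, w = 2, 4, 4

    positions = torch.stack(torch.meshgrid(
        torch.arange(t),
        torch.arange(h),
        torch.arange(w),
        indexing = 'ij'
    ), dim = -1).reshape(-1, 3)

    theta = rotary(positions) # (heads, n, freqs)

    q = torch.randn(2, heads, t * h * w, dim_head)
    k = torch.randn(2, heads, t * h * w, dim_head)

    q = apply_rotations(theta, q)
    k = apply_rotations(theta, k)

Since theta is linear in position, shifting every position by the same offset
should leave the resulting attention scores unchanged - the relative-position
property the tokenizer's encoder and decoder now share.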