rotate in the 3d rotary embeddings for the video tokenizer for both encoder / decoder

This commit is contained in:
lucidrains 2025-10-04 09:22:06 -07:00
parent 93f6738c9c
commit 1b7f6e787d

View File

@ -27,6 +27,7 @@ from accelerate import Accelerator
# f - frequencies (rotary)
# p - positions (3 for spacetime in this work)
# t - time
# g - groups of query heads to key heads (gqa)
# vc - video channels
# vh, vw - video height and width
@ -66,6 +67,8 @@ def first(arr):
def divisible_by(num, den):
return (num % den) == 0
# tensor helpers
def pack_one(t, pattern):
packed, packed_shape = pack([t], pattern)
@ -75,6 +78,19 @@ def pack_one(t, pattern):
return packed, inverse
def pad_at_dim(
t,
pad: tuple[int, int],
dim = -1,
value = 0.
):
if pad == (0, 0):
return t
dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
zeros = ((0, 0) * dims_from_right)
return F.pad(t, (*zeros, *pad), value = value)
def l2norm(t):
return F.normalize(t, dim = -1, p = 2)
@ -258,21 +274,30 @@ class GoldenGateRoPENd(Module):
pos # (b n p)
):
freqs = rearrange(self.freqs, 'h f p -> 1 h 1 f p')
positions = rearrange(pos.float(), 'b n p -> b 1 n 1 p')
freqs = rearrange(self.freqs, 'h f p -> h 1 f p')
positions = rearrange(pos.float(), 'n p -> 1 n 1 p')
# thetas for freqs and positions (batch, head, seq, freq)
theta = reduce(freqs * positions, 'b h n f p -> b h n f', 'sum')
theta = reduce(freqs * positions, 'h n f p -> h n f', 'sum')
return theta
def apply_rotations(
theta # (b h n f)
theta, # (h n f)
qk
):
dtype = x
rotary_heads = theta.shape[0]
heads, dtype = qk.shape[1], qk
x, y = rearrange(x.float(), '... (split d) -> split ... d', split = 2) # (b, h, n, f)
# handle gqa for rotary
if heads < rotary_heads:
assert divisible_by(heads, rotary_heads)
groups = heads // rotary_heads
theta = repeat(theta, 'h ... -> (h g) ...', g = groups)
x, y = rearrange(qk.float(), '... (split d) -> split ... d', split = 2) # (b, h, n, f)
# apply rotations
@ -360,7 +385,7 @@ def softclamp_score_mod(value):
# todo - reuse the inner function from flex attn above with broadcasting
def nonflex_block_causal_mask(seq_len, block_size, device = None):
def block_causal_mask(seq_len, block_size, device = None):
blocks = ceil(seq_len / block_size)
causal_mask = torch.ones((blocks, blocks), device = device, dtype = torch.bool).tril()
@ -443,7 +468,6 @@ class Attention(Module):
# scaling, splitting and merging of heads
self.scale = dim_head ** -0.5
self.split_heads = Rearrange('b n (h d) -> b h n d', d = dim_head)
self.merge_heads = Rearrange('b h n d -> b n (h d)')
@ -464,13 +488,8 @@ class Attention(Module):
tokens, # (b n d)
kv_cache = None,
return_kv_cache = False,
attend_fn: Callable | None = None,
attend_kwargs: dict = dict(
softclamp_value = None,
causal = False,
mask = None,
scale = None
)
rotary_pos_emb = None,
attend_fn: Callable | None = None
):
tokens, inverse_packed_batch = pack_one(tokens, '* n d')
@ -494,11 +513,17 @@ class Attention(Module):
k = cat((ck, k), dim = -2)
v = cat((cv, v), dim = -2)
# rotary
if exists(rotary_pos_emb):
q = apply_rotations(rotary_pos_emb, q)
k = apply_rotations(rotary_pos_emb, k)
# attention
attend_fn = default(attend_fn, naive_attend)
out = attend_fn(q, k, v, **attend_kwargs)
out = attend_fn(q, k, v)
# merge heads
@ -550,17 +575,21 @@ class VideoTokenizer(Module):
patch_size,
encoder_depth = 4,
decoder_depth = 4,
attn_kwargs: dict = dict(
dim_head = 64,
heads = 8,
),
attn_kwargs: dict = dict(),
attn_dim_head = 64,
attn_heads = 8,
attn_softclamp_value = 50.,
ff_kwargs: dict = dict(),
decoder_pos_mlp_depth = 2,
channels = 3,
per_image_patch_mask_prob = (0., 0.9), # probability of patch masking appears to be per image probabilities drawn uniformly between 0. and 0.9 - if you are a phd student and think i'm mistakened, please open an issue
lpips_loss_network: Module | None = None,
lpips_loss_weight = 0.2
lpips_loss_weight = 0.2,
nd_rotary_kwargs: dict = dict(
rope_min_freq = 1.,
rope_max_freq = 10000.,
rope_p_zero_freqs = 0.
)
):
super().__init__()
@ -589,6 +618,15 @@ class VideoTokenizer(Module):
Rearrange('b t h w (p1 p2 c) -> b c t (h p1) (w p2)', p1 = patch_size, p2 = patch_size),
)
# 3d rotations
self.spacetime_rotary = GoldenGateRoPENd(
dim_pos = 3,
heads = attn_heads,
dim_head = attn_dim_head,
**nd_rotary_kwargs
)
# attention related
self.attn_softclamp_value = attn_softclamp_value
@ -599,7 +637,7 @@ class VideoTokenizer(Module):
for _ in range(encoder_depth):
encoder_layers.append(ModuleList([
Attention(dim = dim, **attn_kwargs),
Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, **attn_kwargs),
SwiGLUFeedforward(dim = dim, **ff_kwargs)
]))
@ -630,7 +668,7 @@ class VideoTokenizer(Module):
for _ in range(decoder_depth):
decoder_layers.append(ModuleList([
Attention(dim = dim, **attn_kwargs),
Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, **attn_kwargs),
SwiGLUFeedforward(dim = dim, **ff_kwargs)
]))
@ -662,11 +700,9 @@ class VideoTokenizer(Module):
mask_patches = None,
return_all_losses = False
):
batch, time = video.shape[0], video.shape[2]
batch, _, time, height, width = video.shape
patch_size, device = self.patch_size, video.device
*_, height, width = video.shape
assert divisible_by(height, patch_size) and divisible_by(width, patch_size)
# to tokens
@ -677,6 +713,24 @@ class VideoTokenizer(Module):
num_patch_height, num_patch_width, _ = tokens.shape[-3:]
# rotary positions
positions = stack(torch.meshgrid(
arange(time, device = device),
arange(num_patch_height, device = device),
arange(num_patch_width, device = device)
), dim = -1)
positions = rearrange(positions, 't h w p -> t (h w) p')
# give the latents an out of bounds position and assume the network will figure it out
positions = pad_at_dim(positions, (0, 1), dim = -2, value = -1) # todo - make this value configurable, and ultimately craft own flash attention function where certain positions can be unrotated
positions = rearrange(positions, 't hw p -> (t hw) p')
rotary_pos_emb = self.spacetime_rotary(positions)
# masking
mask_patches = default(mask_patches, self.training)
@ -697,9 +751,9 @@ class VideoTokenizer(Module):
# add the latent
latents = repeat(self.latent_token, 'd -> b t 1 d', b = tokens.shape[0], t = tokens.shape[1])
latents = repeat(self.latent_token, 'd -> b t d', b = tokens.shape[0], t = tokens.shape[1])
tokens = cat((tokens, latents), dim = -2)
tokens, packed_latent_shape = pack((tokens, latents), 'b t * d')
# pack time
@ -709,10 +763,12 @@ class VideoTokenizer(Module):
attend_kwargs = dict(softclamp_value = self.attn_softclamp_value)
attend_fn = partial(naive_attend, **attend_kwargs)
# encoder
for attn, ff in self.encoder_layers:
tokens = attn(tokens, attend_kwargs = attend_kwargs) + tokens
tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = attend_fn) + tokens
tokens = ff(tokens) + tokens
tokens = self.encoder_norm(tokens)
@ -720,9 +776,10 @@ class VideoTokenizer(Module):
# latent bottleneck
tokens = inverse_pack_time(tokens)
tokens = tokens[..., -1, :]
latents = self.encoded_to_latents(tokens)
tokens, latents = unpack(tokens, packed_latent_shape, 'b t * d')
latents = self.encoded_to_latents(latents)
if return_latents:
return latents
@ -744,7 +801,8 @@ class VideoTokenizer(Module):
# decoder attention
for attn, ff in self.decoder_layers:
tokens = attn(tokens, attend_kwargs = attend_kwargs) + tokens
tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = attend_fn) + tokens
tokens = ff(tokens) + tokens
tokens = self.decoder_norm(tokens)
@ -959,6 +1017,14 @@ class DynamicsModel(Module):
tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, agent_token], 'b t * d')
# attend functions for space and time
attend_kwargs = dict(softclamp_value = self.attn_softclamp_value)
space_attend = partial(naive_attend, causal = False, **attend_kwargs)
time_attend = partial(naive_attend, causal = True, **attend_kwargs)
# attention
for (pre_attn_rearrange, post_attn_rearrange, attn, ff), layer_is_time in zip(self.layers, self.is_time):
@ -967,17 +1033,11 @@ class DynamicsModel(Module):
# when is a axial time attention block, should be causal
attend_kwargs = dict()
if layer_is_time:
attend_kwargs.update(
softclamp_value = self.attn_softclamp_value,
causal = True
)
attend_fn = time_attend if layer_is_time else space_attend
# attention layer
tokens = attn(tokens, attend_kwargs = attend_kwargs) + tokens
tokens = attn(tokens, attend_fn = attend_fn) + tokens
tokens = post_attn_rearrange(tokens)