for the temporal attention in dynamics model, do rotary the traditional way

This commit is contained in:
lucidrains 2025-10-04 09:41:36 -07:00
parent 1b7f6e787d
commit e04f9ffec6

View File

@ -281,34 +281,49 @@ class GoldenGateRoPENd(Module):
theta = reduce(freqs * positions, 'h n f p -> h n f', 'sum')
return theta
return cat((theta, theta), dim = -1)
class Rotary1D(Module):
def __init__(
self,
dim_head,
theta = 10000.
):
super().__init__()
inv_freq = 1.0 / (theta ** (arange(0, dim_head, 2).float() / dim_head))
self.register_buffer('inv_freq', inv_freq)
def forward(
self,
seq_len
):
device, dtype = self.inv_freq.device, self.inv_freq.dtype
t = torch.arange(seq_len, device = device).type(dtype)
freqs = einsum(t, self.inv_freq, 'i, j -> i j')
return cat((freqs, freqs), dim = -1)
def apply_rotations(
theta, # (h n f)
qk
rotations, # (h n d) | (n d)
t # (b h n d)
):
rotary_heads = theta.shape[0]
heads, dtype = qk.shape[1], qk
heads, dtype = t.shape[1], t.dtype
t = t.float()
# handle gqa for rotary
if heads < rotary_heads:
if rotations.ndim > 2 and heads < rotations.shape[0]:
assert divisible_by(heads, rotary_heads)
groups = heads // rotary_heads
theta = repeat(theta, 'h ... -> (h g) ...', g = groups)
rotations = repeat(rotations, 'h ... -> (h g) ...', g = groups)
x, y = rearrange(qk.float(), '... (split d) -> split ... d', split = 2) # (b, h, n, f)
x1, x2 = t.chunk(2, dim = -1)
rotated_half_t = cat((-x2, x1), dim = -1)
# apply rotations
cos_theta = torch.cos(theta)
sin_theta = torch.sin(theta)
x_out = x * cos_theta - y * sin_theta
y_out = x * sin_theta + y * cos_theta
out = rearrange([x_out, y_out], 'split ... d -> ... (split d)')
return out.type_as(dtype)
rotated = t * rotations.cos() + rotated_half_t * rotations.sin()
return rotated.type(dtype)
# multi-head rmsnorm
@ -862,9 +877,9 @@ class DynamicsModel(Module):
pred_orig_latent = True, # directly predicting the original x0 data yield better results, rather than velocity (x-space vs v-space)
time_block_every = 4, # every 4th block is time
attn_kwargs: dict = dict(
dim_head = 64,
heads = 8,
),
attn_dim_head = 64,
attn_softclamp_value = 50.,
ff_kwargs: dict = dict(),
loss_weight_fn: Callable = ramp_weight,
@ -908,6 +923,10 @@ class DynamicsModel(Module):
self.attn_softclamp_value = attn_softclamp_value
# time rotary embedding
self.time_rotary = Rotary1D(attn_dim_head)
# transformer
layers = []
@ -925,7 +944,7 @@ class DynamicsModel(Module):
layers.append(ModuleList([
rearrange_to_attend,
rearrange_from_attend,
Attention(dim = dim, **attn_kwargs),
Attention(dim = dim, dim_head = attn_dim_head, **attn_kwargs),
SwiGLUFeedforward(dim = dim, **ff_kwargs)
]))
@ -964,6 +983,8 @@ class DynamicsModel(Module):
latents = self.video_tokenizer.tokenize(video)
time = latents.shape[1]
# flow related
assert not (exists(signal_levels) ^ exists(step_sizes))
@ -1025,6 +1046,10 @@ class DynamicsModel(Module):
time_attend = partial(naive_attend, causal = True, **attend_kwargs)
# rotary
rotary_pos_emb = self.time_rotary(time)
# attention
for (pre_attn_rearrange, post_attn_rearrange, attn, ff), layer_is_time in zip(self.layers, self.is_time):
@ -1035,9 +1060,11 @@ class DynamicsModel(Module):
attend_fn = time_attend if layer_is_time else space_attend
layer_rotary_pos_emb = rotary_pos_emb if layer_is_time else None
# attention layer
tokens = attn(tokens, attend_fn = attend_fn) + tokens
tokens = attn(tokens, rotary_pos_emb = layer_rotary_pos_emb, attend_fn = attend_fn) + tokens
tokens = post_attn_rearrange(tokens)