for the temporal attention in dynamics model, do rotary the traditional way
This commit is contained in:
parent 1b7f6e787d
commit e04f9ffec6
@@ -281,34 +281,49 @@ class GoldenGateRoPENd(Module):
 
         theta = reduce(freqs * positions, 'h n f p -> h n f', 'sum')
 
-        return theta
+        return cat((theta, theta), dim = -1)
+
+class Rotary1D(Module):
+    def __init__(
+        self,
+        dim_head,
+        theta = 10000.
+    ):
+        super().__init__()
+        inv_freq = 1.0 / (theta ** (arange(0, dim_head, 2).float() / dim_head))
+        self.register_buffer('inv_freq', inv_freq)
+
+    def forward(
+        self,
+        seq_len
+    ):
+        device, dtype = self.inv_freq.device, self.inv_freq.dtype
+
+        t = torch.arange(seq_len, device = device).type(dtype)
+        freqs = einsum(t, self.inv_freq, 'i, j -> i j')
+
+        return cat((freqs, freqs), dim = -1)
+
 
 def apply_rotations(
-    theta,     # (h n f)
-    qk
+    rotations, # (h n d) | (n d)
+    t          # (b h n d)
 ):
-    rotary_heads = theta.shape[0]
-    heads, dtype = qk.shape[1], qk.dtype
+    heads, dtype = t.shape[1], t.dtype
+    t = t.float()
 
     # handle gqa for rotary
 
-    if heads < rotary_heads:
+    if rotations.ndim > 2 and heads < rotations.shape[0]:
         assert divisible_by(heads, rotary_heads)
         groups = heads // rotary_heads
-        theta = repeat(theta, 'h ... -> (h g) ...', g = groups)
+        rotations = repeat(rotations, 'h ... -> (h g) ...', g = groups)
 
-    x, y = rearrange(qk.float(), '... (split d) -> split ... d', split = 2) # (b, h, n, f)
+    x1, x2 = t.chunk(2, dim = -1)
+    rotated_half_t = cat((-x2, x1), dim = -1)
 
     # apply rotations
 
-    cos_theta = torch.cos(theta)
-    sin_theta = torch.sin(theta)
-
-    x_out = x * cos_theta - y * sin_theta
-    y_out = x * sin_theta + y * cos_theta
-
-    out = rearrange([x_out, y_out], 'split ... d -> ... (split d)')
-    return out.type_as(dtype)
+    rotated = t * rotations.cos() + rotated_half_t * rotations.sin()
+    return rotated.type(dtype)
 
 # multi-head rmsnorm
 
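This hunk switches apply_rotations to the traditional rotate-half formulation and adds a standard 1D rotary (Rotary1D) for the temporal axis. For reference, a minimal standalone sketch of that formulation (illustrative names only, not the repo's API), composing a Rotary1D-style angle table with the new apply_rotations logic:

    # illustration only (not part of the commit): the traditional rotate-half RoPE
    # that the new Rotary1D + apply_rotations implement, with hypothetical names

    import torch
    from torch import cat, arange

    def rotations_1d(seq_len, dim_head, theta = 10000.):
        # one angle per (position, feature pair), duplicated across the last dim
        inv_freq = 1.0 / (theta ** (arange(0, dim_head, 2).float() / dim_head))
        freqs = torch.outer(arange(seq_len).float(), inv_freq)   # (n, d / 2)
        return cat((freqs, freqs), dim = -1)                     # (n, d)

    def rotate_half(t):
        # pairs feature i with feature i + d/2, matching the duplicated angle layout
        x1, x2 = t.chunk(2, dim = -1)
        return cat((-x2, x1), dim = -1)

    def apply(rotations, t):
        # t: (b, h, n, d), rotations: (n, d), broadcast over batch and heads
        return t * rotations.cos() + rotate_half(t) * rotations.sin()

    q = torch.randn(1, 8, 16, 64)            # (batch, heads, time, dim_head)
    k = torch.randn(1, 8, 16, 64)
    rot = rotations_1d(seq_len = 16, dim_head = 64)
    q, k = apply(rot, q), apply(rot, k)      # positions now enter q . k relatively

Because the angles are duplicated across the last dimension, t * cos + rotate_half(t) * sin is exactly the pairwise 2D rotation of the original RoPE, applied to the feature pairs (i, i + d/2).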
@@ -862,9 +877,9 @@ class DynamicsModel(Module):
         pred_orig_latent = True, # directly predicting the original x0 data yields better results, rather than velocity (x-space vs v-space)
         time_block_every = 4, # every 4th block is time
         attn_kwargs: dict = dict(
-            dim_head = 64,
             heads = 8,
         ),
+        attn_dim_head = 64,
         attn_softclamp_value = 50.,
         ff_kwargs: dict = dict(),
         loss_weight_fn: Callable = ramp_weight,
@@ -908,6 +923,10 @@ class DynamicsModel(Module):
 
         self.attn_softclamp_value = attn_softclamp_value
 
+        # time rotary embedding
+
+        self.time_rotary = Rotary1D(attn_dim_head)
+
         # transformer
 
         layers = []
@@ -925,7 +944,7 @@ class DynamicsModel(Module):
             layers.append(ModuleList([
                 rearrange_to_attend,
                 rearrange_from_attend,
-                Attention(dim = dim, **attn_kwargs),
+                Attention(dim = dim, dim_head = attn_dim_head, **attn_kwargs),
                 SwiGLUFeedforward(dim = dim, **ff_kwargs)
             ]))
 
@@ -964,6 +983,8 @@ class DynamicsModel(Module):
 
         latents = self.video_tokenizer.tokenize(video)
 
+        time = latents.shape[1]
+
         # flow related
 
         assert not (exists(signal_levels) ^ exists(step_sizes))
@@ -1025,6 +1046,10 @@ class DynamicsModel(Module):
 
         time_attend = partial(naive_attend, causal = True, **attend_kwargs)
 
+        # rotary
+
+        rotary_pos_emb = self.time_rotary(time)
+
         # attention
 
         for (pre_attn_rearrange, post_attn_rearrange, attn, ff), layer_is_time in zip(self.layers, self.is_time):
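Assuming the Rotary1D and apply_rotations from the first hunk are in scope (the import path is not shown in this diff), the table built here is a (time, dim_head) tensor of angles, i.e. the (n d) variant that apply_rotations accepts, so it broadcasts over batch and heads:

    # quick shape check; assumes Rotary1D / apply_rotations from the first hunk are in scope
    import torch

    rotary = Rotary1D(dim_head = 64)
    angles = rotary(seq_len = 10)               # (10, 64): the (n d) case of apply_rotations
    tokens = torch.randn(2, 8, 10, 64)          # (b, h, n, d) per-head temporal tokens
    rotated = apply_rotations(angles, tokens)   # same shape, original dtype restored
    assert rotated.shape == tokens.shape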
@@ -1035,9 +1060,11 @@ class DynamicsModel(Module):
 
             attend_fn = time_attend if layer_is_time else space_attend
 
+            layer_rotary_pos_emb = rotary_pos_emb if layer_is_time else None
+
             # attention layer
 
-            tokens = attn(tokens, attend_fn = attend_fn) + tokens
+            tokens = attn(tokens, rotary_pos_emb = layer_rotary_pos_emb, attend_fn = attend_fn) + tokens
 
             tokens = post_attn_rearrange(tokens)
 
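The Attention module itself is not touched by this diff, so how it consumes rotary_pos_emb is not shown here. A hedged sketch of the step it presumably performs, based only on apply_rotations' signature: rotate queries and keys before attend_fn runs, and skip rotation for the spatial layers, which receive None:

    # not the repo's Attention class, just a sketch of the step this commit feeds;
    # assumes q, k, v of shape (b, h, n, d) and the apply_rotations defined above

    def attend_with_rotary(q, k, v, attend_fn, rotary_pos_emb = None):
        if rotary_pos_emb is not None:
            # temporal layers receive the (time, dim_head) angle table; spatial layers pass None
            q = apply_rotations(rotary_pos_emb, q)
            k = apply_rotations(rotary_pos_emb, k)
        return attend_fn(q, k, v)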