for the temporal attention in dynamics model, do rotary the traditional way
commit e04f9ffec6 (parent 1b7f6e787d)

@@ -281,34 +281,49 @@ class GoldenGateRoPENd(Module):

         theta = reduce(freqs * positions, 'h n f p -> h n f', 'sum')

-        return theta
+        return cat((theta, theta), dim = -1)
+
+class Rotary1D(Module):
+    def __init__(
+        self,
+        dim_head,
+        theta = 10000.
+    ):
+        super().__init__()
+
+        inv_freq = 1.0 / (theta ** (arange(0, dim_head, 2).float() / dim_head))
+        self.register_buffer('inv_freq', inv_freq)
+
+    def forward(
+        self,
+        seq_len
+    ):
+        device, dtype = self.inv_freq.device, self.inv_freq.dtype
+
+        t = torch.arange(seq_len, device = device).type(dtype)
+        freqs = einsum(t, self.inv_freq, 'i, j -> i j')
+
+        return cat((freqs, freqs), dim = -1)

 def apply_rotations(
-    theta,     # (h n f)
-    qk
+    rotations, # (h n d) | (n d)
+    t          # (b h n d)
 ):
-    rotary_heads = theta.shape[0]
-    heads, dtype = qk.shape[1], qk.dtype
+    heads, dtype = t.shape[1], t.dtype
+    t = t.float()

     # handle gqa for rotary

-    if heads < rotary_heads:
+    if rotations.ndim > 2 and heads < rotations.shape[0]:
         assert divisible_by(heads, rotary_heads)
         groups = heads // rotary_heads
-        theta = repeat(theta, 'h ... -> (h g) ...', g = groups)
+        rotations = repeat(rotations, 'h ... -> (h g) ...', g = groups)

-    x, y = rearrange(qk.float(), '... (split d) -> split ... d', split = 2) # (b, h, n, f)
+    x1, x2 = t.chunk(2, dim = -1)
+    rotated_half_t = cat((-x2, x1), dim = -1)

-    cos_theta = torch.cos(theta)
-    sin_theta = torch.sin(theta)
+    # apply rotations

-    x_out = x * cos_theta - y * sin_theta
-    y_out = x * sin_theta + y * cos_theta
+    rotated = t * rotations.cos() + rotated_half_t * rotations.sin()

-    out = rearrange([x_out, y_out], 'split ... d -> ... (split d)')
-
-    return out.type_as(dtype)
+    return rotated.type(dtype)

 # multi-head rmsnorm
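The "traditional" scheme referred to in the commit title is the usual rotate-half rotary embedding: Rotary1D builds a (seq_len, dim_head) table by duplicating the per-pair frequencies, and apply_rotations computes t * cos + rotate_half(t) * sin. A minimal, self-contained sketch of the same math in plain PyTorch (no einops; the function names here are illustrative, and the shapes mirror the (n d) rotations and (b h n d) tensors in the diff):

import torch

def rotary_1d(seq_len, dim_head, theta = 10000.):
    # one frequency per feature pair, duplicated to cover the full head dimension
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim_head, 2).float() / dim_head))
    freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)   # (n, d / 2)
    return torch.cat((freqs, freqs), dim = -1)                     # (n, d)

def rotate_half(t):
    x1, x2 = t.chunk(2, dim = -1)
    return torch.cat((-x2, x1), dim = -1)

def apply_rotations_sketch(rotations, t):
    # rotations: (n, d), t: (b, h, n, d); rotations broadcast over batch and heads
    return t * rotations.cos() + rotate_half(t) * rotations.sin()

q = torch.randn(1, 8, 16, 64)        # (batch, heads, time, dim_head)
rotations = rotary_1d(16, 64)        # (time, dim_head)
q_rotated = apply_rotations_sketch(rotations, q)
assert q_rotated.shape == q.shape

Because the rotation table carries no batch or head axis in the common case, it simply broadcasts over both when multiplied in.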
@@ -862,9 +877,9 @@ class DynamicsModel(Module):
         pred_orig_latent = True,  # directly predicting the original x0 data yield better results, rather than velocity (x-space vs v-space)
         time_block_every = 4,     # every 4th block is time
         attn_kwargs: dict = dict(
-            dim_head = 64,
             heads = 8,
         ),
+        attn_dim_head = 64,
         attn_softclamp_value = 50.,
         ff_kwargs: dict = dict(),
         loss_weight_fn: Callable = ramp_weight,
@@ -908,6 +923,10 @@ class DynamicsModel(Module):

         self.attn_softclamp_value = attn_softclamp_value

+        # time rotary embedding
+
+        self.time_rotary = Rotary1D(attn_dim_head)
+
         # transformer

         layers = []
@@ -925,7 +944,7 @@ class DynamicsModel(Module):
            layers.append(ModuleList([
                rearrange_to_attend,
                rearrange_from_attend,
-               Attention(dim = dim, **attn_kwargs),
+               Attention(dim = dim, dim_head = attn_dim_head, **attn_kwargs),
                SwiGLUFeedforward(dim = dim, **ff_kwargs)
            ]))

@@ -964,6 +983,8 @@ class DynamicsModel(Module):

        latents = self.video_tokenizer.tokenize(video)

+       time = latents.shape[1]
+
        # flow related

        assert not (exists(signal_levels) ^ exists(step_sizes))
@@ -1025,6 +1046,10 @@ class DynamicsModel(Module):

        time_attend = partial(naive_attend, causal = True, **attend_kwargs)

+       # rotary
+
+       rotary_pos_emb = self.time_rotary(time)
+
        # attention

        for (pre_attn_rearrange, post_attn_rearrange, attn, ff), layer_is_time in zip(self.layers, self.is_time):
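In the forward pass the table is built once from the number of latent frames and shared by every temporal block. A quick shape check of what self.time_rotary(time) produces, assuming the default attn_dim_head = 64 and an illustrative 16 latent frames:

import torch

dim_head, num_frames = 64, 16   # default head dim from the diff; the frame count is only an example

inv_freq = 1.0 / (10000. ** (torch.arange(0, dim_head, 2).float() / dim_head))
freqs = torch.outer(torch.arange(num_frames).float(), inv_freq)   # (time, dim_head / 2)
rotary_pos_emb = torch.cat((freqs, freqs), dim = -1)              # (time, dim_head)

assert rotary_pos_emb.shape == (num_frames, dim_head)   # broadcast over batch and heads when applied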
@@ -1035,9 +1060,11 @@ class DynamicsModel(Module):

            attend_fn = time_attend if layer_is_time else space_attend

+           layer_rotary_pos_emb = rotary_pos_emb if layer_is_time else None
+
            # attention layer

-           tokens = attn(tokens, attend_fn = attend_fn) + tokens
+           tokens = attn(tokens, rotary_pos_emb = layer_rotary_pos_emb, attend_fn = attend_fn) + tokens

            tokens = post_attn_rearrange(tokens)

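How the Attention module consumes rotary_pos_emb is not shown in this diff; presumably it rotates queries and keys before invoking attend_fn and leaves them untouched when None is passed, which is what the spatial layers now receive. A hedged sketch of that pattern with a hypothetical helper (not the repo's actual Attention.forward):

import torch

def rotate_half(t):
    x1, x2 = t.chunk(2, dim = -1)
    return torch.cat((-x2, x1), dim = -1)

def maybe_rotate_qk(q, k, rotary_pos_emb = None):
    # hypothetical helper: apply the temporal rotations to queries and keys
    # only when a table is provided, otherwise return them unchanged
    if rotary_pos_emb is None:
        return q, k

    cos, sin = rotary_pos_emb.cos(), rotary_pos_emb.sin()
    q = q * cos + rotate_half(q) * sin
    k = k * cos + rotate_half(k) * sin
    return q, k

# time blocks receive the (time, dim_head) table, space blocks receive None
q = k = torch.randn(1, 8, 16, 64)          # (batch, heads, time, dim_head)
table = torch.randn(16, 64)                # stand-in for self.time_rotary(time)
q_rot, k_rot = maybe_rotate_qk(q, k, table)
q_id, k_id = maybe_rotate_qk(q, k, None)   # spatial layers: q and k pass through untouched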