rotate in the 3d rotary embeddings for the video tokenizer for both encoder / decoder
commit 1b7f6e787d
parent 93f6738c9c
@@ -27,6 +27,7 @@ from accelerate import Accelerator
 # f - frequencies (rotary)
 # p - positions (3 for spacetime in this work)
 # t - time
+# g - groups of query heads to key heads (gqa)
 # vc - video channels
 # vh, vw - video height and width
 
@@ -66,6 +67,8 @@ def first(arr):
 def divisible_by(num, den):
     return (num % den) == 0
 
+# tensor helpers
+
 def pack_one(t, pattern):
     packed, packed_shape = pack([t], pattern)
 
@@ -75,6 +78,19 @@ def pack_one(t, pattern):
 
     return packed, inverse
 
+def pad_at_dim(
+    t,
+    pad: tuple[int, int],
+    dim = -1,
+    value = 0.
+):
+    if pad == (0, 0):
+        return t
+
+    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
+    zeros = ((0, 0) * dims_from_right)
+    return F.pad(t, (*zeros, *pad), value = value)
+
 def l2norm(t):
     return F.normalize(t, dim = -1, p = 2)
 
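For reference, the new pad_at_dim helper maps a (left, right) pad on an arbitrary dimension onto F.pad, which expects pads ordered from the last dimension inward. A minimal, self-contained check (the shapes here are illustrative, not from the repo):

import torch
import torch.nn.functional as F

def pad_at_dim(t, pad, dim = -1, value = 0.):
    if pad == (0, 0):
        return t
    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = ((0, 0) * dims_from_right)          # no padding for the dims to the right of `dim`
    return F.pad(t, (*zeros, *pad), value = value)

x = torch.zeros(2, 6, 3)
print(pad_at_dim(x, (0, 1), dim = -2, value = -1.).shape)  # torch.Size([2, 7, 3])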
@@ -258,21 +274,30 @@ class GoldenGateRoPENd(Module):
         pos # (b n p)
     ):
 
-        freqs = rearrange(self.freqs, 'h f p -> 1 h 1 f p')
-        positions = rearrange(pos.float(), 'b n p -> b 1 n 1 p')
+        freqs = rearrange(self.freqs, 'h f p -> h 1 f p')
+        positions = rearrange(pos.float(), 'n p -> 1 n 1 p')
 
         # thetas for freqs and positions (batch, head, seq, freq)
 
-        theta = reduce(freqs * positions, 'b h n f p -> b h n f', 'sum')
+        theta = reduce(freqs * positions, 'h n f p -> h n f', 'sum')
 
         return theta
 
 def apply_rotations(
-    theta # (b h n f)
+    theta, # (h n f)
+    qk
 ):
-    dtype = x.dtype
+    rotary_heads = theta.shape[0]
+    heads, dtype = qk.shape[1], qk.dtype
 
-    x, y = rearrange(x.float(), '... (split d) -> split ... d', split = 2) # (b, h, n, f)
+    # handle gqa for rotary
+
+    if heads < rotary_heads:
+        assert divisible_by(heads, rotary_heads)
+        groups = heads // rotary_heads
+        theta = repeat(theta, 'h ... -> (h g) ...', g = groups)
+
+    x, y = rearrange(qk.float(), '... (split d) -> split ... d', split = 2) # (b, h, n, f)
 
     # apply rotations
 
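The hunk above cuts off before the rotation itself. As a rough guide, a standard pairwise rotary rotation over the theta produced here could look like the sketch below; the function name, the cos/sin formulation, and the final cast back to the input dtype are assumptions, not the repository's exact code.

import torch
from einops import rearrange

def rotate_halves(theta, qk):
    # theta: (h n f), qk: (b h n d) with d == 2 * f
    dtype = qk.dtype
    x, y = rearrange(qk.float(), '... (split d) -> split ... d', split = 2)
    cos, sin = theta.cos(), theta.sin()
    rotated = torch.cat((x * cos - y * sin, x * sin + y * cos), dim = -1)
    return rotated.type(dtype)

theta = torch.randn(8, 16, 32)
qk = torch.randn(2, 8, 16, 64)
print(rotate_halves(theta, qk).shape)  # torch.Size([2, 8, 16, 64])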
@@ -360,7 +385,7 @@ def softclamp_score_mod(value):
 
 # todo - reuse the inner function from flex attn above with broadcasting
 
-def nonflex_block_causal_mask(seq_len, block_size, device = None):
+def block_causal_mask(seq_len, block_size, device = None):
     blocks = ceil(seq_len / block_size)
 
     causal_mask = torch.ones((blocks, blocks), device = device, dtype = torch.bool).tril()
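The renamed block_causal_mask is cut off after the block-level tril; for intuition, the sketch below builds the same block-level pattern and expands it to token resolution with a repeat, which is one plausible way the rest of the function proceeds (the expansion step is an assumption).

import torch
from math import ceil
from einops import repeat

seq_len, block_size = 6, 2
blocks = ceil(seq_len / block_size)

causal_mask = torch.ones((blocks, blocks), dtype = torch.bool).tril()

# expand each block entry to block_size x block_size tokens, then trim to seq_len
token_mask = repeat(causal_mask, 'i j -> (i p1) (j p2)', p1 = block_size, p2 = block_size)
token_mask = token_mask[:seq_len, :seq_len]

print(token_mask.int())  # tokens within a block attend to each other; future blocks stay masked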
@@ -443,7 +468,6 @@ class Attention(Module):
 
         # scaling, splitting and merging of heads
 
-        self.scale = dim_head ** -0.5
         self.split_heads = Rearrange('b n (h d) -> b h n d', d = dim_head)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
 
@@ -464,13 +488,8 @@ class Attention(Module):
         tokens, # (b n d)
         kv_cache = None,
         return_kv_cache = False,
-        attend_fn: Callable | None = None,
-        attend_kwargs: dict = dict(
-            softclamp_value = None,
-            causal = False,
-            mask = None,
-            scale = None
-        )
+        rotary_pos_emb = None,
+        attend_fn: Callable | None = None
     ):
         tokens, inverse_packed_batch = pack_one(tokens, '* n d')
 
@@ -494,11 +513,17 @@ class Attention(Module):
             k = cat((ck, k), dim = -2)
             v = cat((cv, v), dim = -2)
 
+        # rotary
+
+        if exists(rotary_pos_emb):
+            q = apply_rotations(rotary_pos_emb, q)
+            k = apply_rotations(rotary_pos_emb, k)
+
         # attention
 
         attend_fn = default(attend_fn, naive_attend)
 
-        out = attend_fn(q, k, v, **attend_kwargs)
+        out = attend_fn(q, k, v)
 
         # merge heads
 
@@ -550,17 +575,21 @@ class VideoTokenizer(Module):
         patch_size,
         encoder_depth = 4,
         decoder_depth = 4,
-        attn_kwargs: dict = dict(
-            dim_head = 64,
-            heads = 8,
-        ),
+        attn_kwargs: dict = dict(),
+        attn_dim_head = 64,
+        attn_heads = 8,
         attn_softclamp_value = 50.,
         ff_kwargs: dict = dict(),
         decoder_pos_mlp_depth = 2,
         channels = 3,
         per_image_patch_mask_prob = (0., 0.9), # probability of patch masking appears to be per image probabilities drawn uniformly between 0. and 0.9 - if you are a phd student and think i'm mistakened, please open an issue
         lpips_loss_network: Module | None = None,
-        lpips_loss_weight = 0.2
+        lpips_loss_weight = 0.2,
+        nd_rotary_kwargs: dict = dict(
+            rope_min_freq = 1.,
+            rope_max_freq = 10000.,
+            rope_p_zero_freqs = 0.
+        )
     ):
         super().__init__()
 
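The per_image_patch_mask_prob comment describes drawing a masking probability per image uniformly from (0., 0.9) and then masking patches with that probability. A hedged, standalone sketch of that sampling scheme (names and tensor layout are illustrative, not taken from the repo):

import torch

def sample_patch_mask(batch, num_patches, prob_range = (0., 0.9), device = None):
    lo, hi = prob_range
    per_image_prob = torch.empty(batch, device = device).uniform_(lo, hi)        # (b,) one probability per image
    mask = torch.rand(batch, num_patches, device = device) < per_image_prob[:, None]
    return mask                                                                  # True = patch masked out

mask = sample_patch_mask(4, 256)
print(mask.float().mean(dim = -1))  # roughly the per-image probabilities that were drawn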
@@ -589,6 +618,15 @@ class VideoTokenizer(Module):
             Rearrange('b t h w (p1 p2 c) -> b c t (h p1) (w p2)', p1 = patch_size, p2 = patch_size),
         )
 
+        # 3d rotations
+
+        self.spacetime_rotary = GoldenGateRoPENd(
+            dim_pos = 3,
+            heads = attn_heads,
+            dim_head = attn_dim_head,
+            **nd_rotary_kwargs
+        )
+
         # attention related
 
         self.attn_softclamp_value = attn_softclamp_value
@@ -599,7 +637,7 @@ class VideoTokenizer(Module):
 
         for _ in range(encoder_depth):
             encoder_layers.append(ModuleList([
-                Attention(dim = dim, **attn_kwargs),
+                Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, **attn_kwargs),
                 SwiGLUFeedforward(dim = dim, **ff_kwargs)
             ]))
 
@@ -630,7 +668,7 @@ class VideoTokenizer(Module):
 
         for _ in range(decoder_depth):
             decoder_layers.append(ModuleList([
-                Attention(dim = dim, **attn_kwargs),
+                Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, **attn_kwargs),
                 SwiGLUFeedforward(dim = dim, **ff_kwargs)
             ]))
 
@@ -662,11 +700,9 @@ class VideoTokenizer(Module):
         mask_patches = None,
         return_all_losses = False
     ):
-        batch, time = video.shape[0], video.shape[2]
+        batch, _, time, height, width = video.shape
         patch_size, device = self.patch_size, video.device
 
-        *_, height, width = video.shape
-
         assert divisible_by(height, patch_size) and divisible_by(width, patch_size)
 
         # to tokens
@@ -677,6 +713,24 @@ class VideoTokenizer(Module):
 
         num_patch_height, num_patch_width, _ = tokens.shape[-3:]
 
+        # rotary positions
+
+        positions = stack(torch.meshgrid(
+            arange(time, device = device),
+            arange(num_patch_height, device = device),
+            arange(num_patch_width, device = device)
+        ), dim = -1)
+
+        positions = rearrange(positions, 't h w p -> t (h w) p')
+
+        # give the latents an out of bounds position and assume the network will figure it out
+
+        positions = pad_at_dim(positions, (0, 1), dim = -2, value = -1) # todo - make this value configurable, and ultimately craft own flash attention function where certain positions can be unrotated
+
+        positions = rearrange(positions, 't hw p -> (t hw) p')
+
+        rotary_pos_emb = self.spacetime_rotary(positions)
+
         # masking
 
         mask_patches = default(mask_patches, self.training)
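A quick shape walk-through of the spacetime position grid built above, using toy sizes; torch.cat stands in for pad_at_dim only to keep the snippet self-contained, and indexing = 'ij' is added just to silence the meshgrid warning.

import torch
from torch import stack, arange
from einops import rearrange

time, num_patch_height, num_patch_width = 2, 2, 3

positions = stack(torch.meshgrid(
    arange(time),
    arange(num_patch_height),
    arange(num_patch_width),
    indexing = 'ij'
), dim = -1)                                              # (t, h, w, 3) spacetime coordinates

positions = rearrange(positions, 't h w p -> t (h w) p')  # (t, h * w, 3)

# one extra out-of-bounds position of -1 per frame, for the latent token
latent_pos = torch.full((time, 1, 3), -1)
positions = torch.cat((positions, latent_pos), dim = -2)  # (t, h * w + 1, 3)

positions = rearrange(positions, 't hw p -> (t hw) p')
print(positions.shape)                                    # torch.Size([14, 3])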
@@ -697,9 +751,9 @@ class VideoTokenizer(Module):
 
         # add the latent
 
-        latents = repeat(self.latent_token, 'd -> b t 1 d', b = tokens.shape[0], t = tokens.shape[1])
+        latents = repeat(self.latent_token, 'd -> b t d', b = tokens.shape[0], t = tokens.shape[1])
 
-        tokens = cat((tokens, latents), dim = -2)
+        tokens, packed_latent_shape = pack((tokens, latents), 'b t * d')
 
         # pack time
 
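The cat along the token dimension is replaced by einops pack, which also records the packed shapes so the latent can be split back out after the encoder (see the unpack further down). A toy round trip with illustrative shapes:

import torch
from einops import pack, unpack, repeat

b, t, n, d = 2, 4, 9, 8

tokens = torch.randn(b, t, n, d)
latent_token = torch.randn(d)

latents = repeat(latent_token, 'd -> b t d', b = b, t = t)

packed, packed_latent_shape = pack((tokens, latents), 'b t * d')          # (b, t, n + 1, d)
tokens_out, latents_out = unpack(packed, packed_latent_shape, 'b t * d')

print(packed.shape, tokens_out.shape, latents_out.shape)
# torch.Size([2, 4, 10, 8]) torch.Size([2, 4, 9, 8]) torch.Size([2, 4, 8])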
@@ -709,10 +763,12 @@ class VideoTokenizer(Module):
 
         attend_kwargs = dict(softclamp_value = self.attn_softclamp_value)
 
+        attend_fn = partial(naive_attend, **attend_kwargs)
+
         # encoder
 
         for attn, ff in self.encoder_layers:
-            tokens = attn(tokens, attend_kwargs = attend_kwargs) + tokens
+            tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = attend_fn) + tokens
             tokens = ff(tokens) + tokens
 
         tokens = self.encoder_norm(tokens)
@@ -720,9 +776,10 @@ class VideoTokenizer(Module):
         # latent bottleneck
 
         tokens = inverse_pack_time(tokens)
-        tokens = tokens[..., -1, :]
 
-        latents = self.encoded_to_latents(tokens)
+        tokens, latents = unpack(tokens, packed_latent_shape, 'b t * d')
+
+        latents = self.encoded_to_latents(latents)
 
         if return_latents:
             return latents
@@ -744,7 +801,8 @@ class VideoTokenizer(Module):
         # decoder attention
 
         for attn, ff in self.decoder_layers:
-            tokens = attn(tokens, attend_kwargs = attend_kwargs) + tokens
+            tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = attend_fn) + tokens
+
             tokens = ff(tokens) + tokens
 
         tokens = self.decoder_norm(tokens)
@@ -959,6 +1017,14 @@ class DynamicsModel(Module):
 
         tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, agent_token], 'b t * d')
 
+        # attend functions for space and time
+
+        attend_kwargs = dict(softclamp_value = self.attn_softclamp_value)
+
+        space_attend = partial(naive_attend, causal = False, **attend_kwargs)
+
+        time_attend = partial(naive_attend, causal = True, **attend_kwargs)
+
        # attention
 
         for (pre_attn_rearrange, post_attn_rearrange, attn, ff), layer_is_time in zip(self.layers, self.is_time):
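naive_attend itself is not shown in this diff; as a point of reference, a softclamped, optionally causal attention of the kind these partials configure is sketched below. This is a generic implementation under assumed argument names, not the repository's function; with functools.partial, space_attend and time_attend then simply fix causal and softclamp_value as in the hunk above.

import torch
from einops import einsum

def naive_attend_sketch(q, k, v, causal = False, softclamp_value = None, scale = None):
    scale = scale if scale is not None else q.shape[-1] ** -0.5
    sim = einsum(q, k, 'b h i d, b h j d -> b h i j') * scale

    if softclamp_value is not None:
        sim = torch.tanh(sim / softclamp_value) * softclamp_value        # soft cap on attention logits

    if causal:
        i, j = sim.shape[-2:]
        causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)
        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

    attn = sim.softmax(dim = -1)
    return einsum(attn, v, 'b h i j, b h j d -> b h i d')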
@@ -967,17 +1033,11 @@ class DynamicsModel(Module):
 
             # when is a axial time attention block, should be causal
 
-            attend_kwargs = dict()
-
-            if layer_is_time:
-                attend_kwargs.update(
-                    softclamp_value = self.attn_softclamp_value,
-                    causal = True
-                )
+            attend_fn = time_attend if layer_is_time else space_attend
 
             # attention layer
 
-            tokens = attn(tokens, attend_kwargs = attend_kwargs) + tokens
+            tokens = attn(tokens, attend_fn = attend_fn) + tokens
 
             tokens = post_attn_rearrange(tokens)
 