rotate in the 3d rotary embeddings for the video tokenizer for both encoder / decoder
This commit is contained in:
parent
93f6738c9c
commit
1b7f6e787d
@ -27,6 +27,7 @@ from accelerate import Accelerator
|
||||
# f - frequencies (rotary)
|
||||
# p - positions (3 for spacetime in this work)
|
||||
# t - time
|
||||
# g - groups of query heads to key heads (gqa)
|
||||
# vc - video channels
|
||||
# vh, vw - video height and width
|
||||
|
||||
@ -66,6 +67,8 @@ def first(arr):
|
||||
def divisible_by(num, den):
|
||||
return (num % den) == 0
|
||||
|
||||
# tensor helpers
|
||||
|
||||
def pack_one(t, pattern):
|
||||
packed, packed_shape = pack([t], pattern)
|
||||
|
||||
@ -75,6 +78,19 @@ def pack_one(t, pattern):
|
||||
|
||||
return packed, inverse
|
||||
|
||||
def pad_at_dim(
|
||||
t,
|
||||
pad: tuple[int, int],
|
||||
dim = -1,
|
||||
value = 0.
|
||||
):
|
||||
if pad == (0, 0):
|
||||
return t
|
||||
|
||||
dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
|
||||
zeros = ((0, 0) * dims_from_right)
|
||||
return F.pad(t, (*zeros, *pad), value = value)
|
||||
|
||||
def l2norm(t):
|
||||
return F.normalize(t, dim = -1, p = 2)
|
||||
|
||||
@ -258,21 +274,30 @@ class GoldenGateRoPENd(Module):
|
||||
pos # (b n p)
|
||||
):
|
||||
|
||||
freqs = rearrange(self.freqs, 'h f p -> 1 h 1 f p')
|
||||
positions = rearrange(pos.float(), 'b n p -> b 1 n 1 p')
|
||||
freqs = rearrange(self.freqs, 'h f p -> h 1 f p')
|
||||
positions = rearrange(pos.float(), 'n p -> 1 n 1 p')
|
||||
|
||||
# thetas for freqs and positions (batch, head, seq, freq)
|
||||
|
||||
theta = reduce(freqs * positions, 'b h n f p -> b h n f', 'sum')
|
||||
theta = reduce(freqs * positions, 'h n f p -> h n f', 'sum')
|
||||
|
||||
return theta
|
||||
|
||||
def apply_rotations(
|
||||
theta # (b h n f)
|
||||
theta, # (h n f)
|
||||
qk
|
||||
):
|
||||
dtype = x
|
||||
rotary_heads = theta.shape[0]
|
||||
heads, dtype = qk.shape[1], qk
|
||||
|
||||
x, y = rearrange(x.float(), '... (split d) -> split ... d', split = 2) # (b, h, n, f)
|
||||
# handle gqa for rotary
|
||||
|
||||
if heads < rotary_heads:
|
||||
assert divisible_by(heads, rotary_heads)
|
||||
groups = heads // rotary_heads
|
||||
theta = repeat(theta, 'h ... -> (h g) ...', g = groups)
|
||||
|
||||
x, y = rearrange(qk.float(), '... (split d) -> split ... d', split = 2) # (b, h, n, f)
|
||||
|
||||
# apply rotations
|
||||
|
||||
@ -360,7 +385,7 @@ def softclamp_score_mod(value):
|
||||
|
||||
# todo - reuse the inner function from flex attn above with broadcasting
|
||||
|
||||
def nonflex_block_causal_mask(seq_len, block_size, device = None):
|
||||
def block_causal_mask(seq_len, block_size, device = None):
|
||||
blocks = ceil(seq_len / block_size)
|
||||
|
||||
causal_mask = torch.ones((blocks, blocks), device = device, dtype = torch.bool).tril()
|
||||
@ -443,7 +468,6 @@ class Attention(Module):
|
||||
|
||||
# scaling, splitting and merging of heads
|
||||
|
||||
self.scale = dim_head ** -0.5
|
||||
self.split_heads = Rearrange('b n (h d) -> b h n d', d = dim_head)
|
||||
self.merge_heads = Rearrange('b h n d -> b n (h d)')
|
||||
|
||||
@ -464,13 +488,8 @@ class Attention(Module):
|
||||
tokens, # (b n d)
|
||||
kv_cache = None,
|
||||
return_kv_cache = False,
|
||||
attend_fn: Callable | None = None,
|
||||
attend_kwargs: dict = dict(
|
||||
softclamp_value = None,
|
||||
causal = False,
|
||||
mask = None,
|
||||
scale = None
|
||||
)
|
||||
rotary_pos_emb = None,
|
||||
attend_fn: Callable | None = None
|
||||
):
|
||||
tokens, inverse_packed_batch = pack_one(tokens, '* n d')
|
||||
|
||||
@ -494,11 +513,17 @@ class Attention(Module):
|
||||
k = cat((ck, k), dim = -2)
|
||||
v = cat((cv, v), dim = -2)
|
||||
|
||||
# rotary
|
||||
|
||||
if exists(rotary_pos_emb):
|
||||
q = apply_rotations(rotary_pos_emb, q)
|
||||
k = apply_rotations(rotary_pos_emb, k)
|
||||
|
||||
# attention
|
||||
|
||||
attend_fn = default(attend_fn, naive_attend)
|
||||
|
||||
out = attend_fn(q, k, v, **attend_kwargs)
|
||||
out = attend_fn(q, k, v)
|
||||
|
||||
# merge heads
|
||||
|
||||
@ -550,17 +575,21 @@ class VideoTokenizer(Module):
|
||||
patch_size,
|
||||
encoder_depth = 4,
|
||||
decoder_depth = 4,
|
||||
attn_kwargs: dict = dict(
|
||||
dim_head = 64,
|
||||
heads = 8,
|
||||
),
|
||||
attn_kwargs: dict = dict(),
|
||||
attn_dim_head = 64,
|
||||
attn_heads = 8,
|
||||
attn_softclamp_value = 50.,
|
||||
ff_kwargs: dict = dict(),
|
||||
decoder_pos_mlp_depth = 2,
|
||||
channels = 3,
|
||||
per_image_patch_mask_prob = (0., 0.9), # probability of patch masking appears to be per image probabilities drawn uniformly between 0. and 0.9 - if you are a phd student and think i'm mistakened, please open an issue
|
||||
lpips_loss_network: Module | None = None,
|
||||
lpips_loss_weight = 0.2
|
||||
lpips_loss_weight = 0.2,
|
||||
nd_rotary_kwargs: dict = dict(
|
||||
rope_min_freq = 1.,
|
||||
rope_max_freq = 10000.,
|
||||
rope_p_zero_freqs = 0.
|
||||
)
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -589,6 +618,15 @@ class VideoTokenizer(Module):
|
||||
Rearrange('b t h w (p1 p2 c) -> b c t (h p1) (w p2)', p1 = patch_size, p2 = patch_size),
|
||||
)
|
||||
|
||||
# 3d rotations
|
||||
|
||||
self.spacetime_rotary = GoldenGateRoPENd(
|
||||
dim_pos = 3,
|
||||
heads = attn_heads,
|
||||
dim_head = attn_dim_head,
|
||||
**nd_rotary_kwargs
|
||||
)
|
||||
|
||||
# attention related
|
||||
|
||||
self.attn_softclamp_value = attn_softclamp_value
|
||||
@ -599,7 +637,7 @@ class VideoTokenizer(Module):
|
||||
|
||||
for _ in range(encoder_depth):
|
||||
encoder_layers.append(ModuleList([
|
||||
Attention(dim = dim, **attn_kwargs),
|
||||
Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, **attn_kwargs),
|
||||
SwiGLUFeedforward(dim = dim, **ff_kwargs)
|
||||
]))
|
||||
|
||||
@ -630,7 +668,7 @@ class VideoTokenizer(Module):
|
||||
|
||||
for _ in range(decoder_depth):
|
||||
decoder_layers.append(ModuleList([
|
||||
Attention(dim = dim, **attn_kwargs),
|
||||
Attention(dim = dim, heads = attn_heads, dim_head = attn_dim_head, **attn_kwargs),
|
||||
SwiGLUFeedforward(dim = dim, **ff_kwargs)
|
||||
]))
|
||||
|
||||
@ -662,11 +700,9 @@ class VideoTokenizer(Module):
|
||||
mask_patches = None,
|
||||
return_all_losses = False
|
||||
):
|
||||
batch, time = video.shape[0], video.shape[2]
|
||||
batch, _, time, height, width = video.shape
|
||||
patch_size, device = self.patch_size, video.device
|
||||
|
||||
*_, height, width = video.shape
|
||||
|
||||
assert divisible_by(height, patch_size) and divisible_by(width, patch_size)
|
||||
|
||||
# to tokens
|
||||
@ -677,6 +713,24 @@ class VideoTokenizer(Module):
|
||||
|
||||
num_patch_height, num_patch_width, _ = tokens.shape[-3:]
|
||||
|
||||
# rotary positions
|
||||
|
||||
positions = stack(torch.meshgrid(
|
||||
arange(time, device = device),
|
||||
arange(num_patch_height, device = device),
|
||||
arange(num_patch_width, device = device)
|
||||
), dim = -1)
|
||||
|
||||
positions = rearrange(positions, 't h w p -> t (h w) p')
|
||||
|
||||
# give the latents an out of bounds position and assume the network will figure it out
|
||||
|
||||
positions = pad_at_dim(positions, (0, 1), dim = -2, value = -1) # todo - make this value configurable, and ultimately craft own flash attention function where certain positions can be unrotated
|
||||
|
||||
positions = rearrange(positions, 't hw p -> (t hw) p')
|
||||
|
||||
rotary_pos_emb = self.spacetime_rotary(positions)
|
||||
|
||||
# masking
|
||||
|
||||
mask_patches = default(mask_patches, self.training)
|
||||
@ -697,9 +751,9 @@ class VideoTokenizer(Module):
|
||||
|
||||
# add the latent
|
||||
|
||||
latents = repeat(self.latent_token, 'd -> b t 1 d', b = tokens.shape[0], t = tokens.shape[1])
|
||||
latents = repeat(self.latent_token, 'd -> b t d', b = tokens.shape[0], t = tokens.shape[1])
|
||||
|
||||
tokens = cat((tokens, latents), dim = -2)
|
||||
tokens, packed_latent_shape = pack((tokens, latents), 'b t * d')
|
||||
|
||||
# pack time
|
||||
|
||||
@ -709,10 +763,12 @@ class VideoTokenizer(Module):
|
||||
|
||||
attend_kwargs = dict(softclamp_value = self.attn_softclamp_value)
|
||||
|
||||
attend_fn = partial(naive_attend, **attend_kwargs)
|
||||
|
||||
# encoder
|
||||
|
||||
for attn, ff in self.encoder_layers:
|
||||
tokens = attn(tokens, attend_kwargs = attend_kwargs) + tokens
|
||||
tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = attend_fn) + tokens
|
||||
tokens = ff(tokens) + tokens
|
||||
|
||||
tokens = self.encoder_norm(tokens)
|
||||
@ -720,9 +776,10 @@ class VideoTokenizer(Module):
|
||||
# latent bottleneck
|
||||
|
||||
tokens = inverse_pack_time(tokens)
|
||||
tokens = tokens[..., -1, :]
|
||||
|
||||
latents = self.encoded_to_latents(tokens)
|
||||
tokens, latents = unpack(tokens, packed_latent_shape, 'b t * d')
|
||||
|
||||
latents = self.encoded_to_latents(latents)
|
||||
|
||||
if return_latents:
|
||||
return latents
|
||||
@ -744,7 +801,8 @@ class VideoTokenizer(Module):
|
||||
# decoder attention
|
||||
|
||||
for attn, ff in self.decoder_layers:
|
||||
tokens = attn(tokens, attend_kwargs = attend_kwargs) + tokens
|
||||
tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = attend_fn) + tokens
|
||||
|
||||
tokens = ff(tokens) + tokens
|
||||
|
||||
tokens = self.decoder_norm(tokens)
|
||||
@ -959,6 +1017,14 @@ class DynamicsModel(Module):
|
||||
|
||||
tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, agent_token], 'b t * d')
|
||||
|
||||
# attend functions for space and time
|
||||
|
||||
attend_kwargs = dict(softclamp_value = self.attn_softclamp_value)
|
||||
|
||||
space_attend = partial(naive_attend, causal = False, **attend_kwargs)
|
||||
|
||||
time_attend = partial(naive_attend, causal = True, **attend_kwargs)
|
||||
|
||||
# attention
|
||||
|
||||
for (pre_attn_rearrange, post_attn_rearrange, attn, ff), layer_is_time in zip(self.layers, self.is_time):
|
||||
@ -967,17 +1033,11 @@ class DynamicsModel(Module):
|
||||
|
||||
# when is a axial time attention block, should be causal
|
||||
|
||||
attend_kwargs = dict()
|
||||
|
||||
if layer_is_time:
|
||||
attend_kwargs.update(
|
||||
softclamp_value = self.attn_softclamp_value,
|
||||
causal = True
|
||||
)
|
||||
attend_fn = time_attend if layer_is_time else space_attend
|
||||
|
||||
# attention layer
|
||||
|
||||
tokens = attn(tokens, attend_kwargs = attend_kwargs) + tokens
|
||||
tokens = attn(tokens, attend_fn = attend_fn) + tokens
|
||||
|
||||
tokens = post_attn_rearrange(tokens)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user