flesh out tokenizer even more
parent 31c4aa28c7
commit 0285bba821
@@ -24,7 +24,7 @@ from accelerate import Accelerator
# vh, vw - video height and width

import einx
-from einops import einsum, rearrange, repeat, reduce
+from einops import einsum, rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange

# flex attention - but will make sure it works if it is not available
@@ -51,9 +51,21 @@ def exists(v):

def default(v, d):
    return v if exists(v) else d

def first(arr):
    return arr[0]

def divisible_by(num, den):
    return (num % den) == 0

def pack_one(t, pattern):
    packed, packed_shape = pack([t], pattern)

    def inverse(out, inv_pattern = None):
        inv_pattern = default(inv_pattern, pattern)
        return first(unpack(out, packed_shape, inv_pattern))

    return packed, inverse

def l2norm(t):
    return F.normalize(t, dim = -1, p = 2)
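For orientation, a minimal usage sketch of the new pack_one helper, assuming it is in scope as defined above; the tensor sizes are made up for illustration:

import torch

tokens = torch.randn(2, 4, 8, 8, 512)            # (batch, time, patch height, patch width, dim)

packed, inverse = pack_one(tokens, 'b t * d')    # flatten the spatial axes into one
print(packed.shape)                              # torch.Size([2, 4, 64, 512])

restored = inverse(packed)                       # the returned closure undoes the pack
assert restored.shape == tokens.shape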
@@ -337,12 +349,21 @@ class VideoTokenizer(Module):
        dim,
        dim_latent,
        patch_size,
        encoder_depth = 4,
        decoder_depth = 4,
        attn_kwargs: dict = dict(),
        ff_kwargs: dict = dict(),
        channels = 3
    ):
        super().__init__()

        self.patch_size = patch_size

        # special tokens

        self.latent_token = Parameter(torch.randn(dim) * 1e-2)
        self.mask_token = Parameter(torch.randn(dim) * 1e-2)

        # patch and unpatch

        dim_patch = channels * patch_size ** 2
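Here dim_patch is the flattened size of one patch (channels x patch_size squared). A patching layer consistent with it could look like the sketch below; the actual patch_to_tokens definition sits outside this diff, so both the layer and the assumed (batch, channels, time, height, width) layout are assumptions:

from torch import nn
from einops.layers.torch import Rearrange

channels, patch_size, dim = 3, 8, 512
dim_patch = channels * patch_size ** 2    # 3 * 8 * 8 = 192 values per patch

# hypothetical patch_to_tokens, assuming a (batch, channels, time, height, width) video
patch_to_tokens = nn.Sequential(
    Rearrange('b c t (h p1) (w p2) -> b t h w (c p1 p2)', p1 = patch_size, p2 = patch_size),
    nn.Linear(dim_patch, dim)
)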
@@ -361,7 +382,14 @@ class VideoTokenizer(Module):

        encoder_layers = []

        for _ in range(encoder_depth):
            encoder_layers.append(ModuleList([
                Attention(dim = dim, **attn_kwargs),
                SwiGLUFeedforward(dim = dim, **ff_kwargs)
            ]))

        self.encoder_layers = ModuleList(encoder_layers)
        self.encoder_norm = RMSNorm(dim)

        # latents
@@ -376,7 +404,14 @@ class VideoTokenizer(Module):

        decoder_layers = []

        for _ in range(decoder_depth):
            decoder_layers.append(ModuleList([
                Attention(dim = dim, **attn_kwargs),
                SwiGLUFeedforward(dim = dim, **ff_kwargs)
            ]))

        self.decoder_layers = ModuleList(decoder_layers)
        self.decoder_norm = RMSNorm(dim)

    def forward(
        self,
@@ -393,6 +428,26 @@ class VideoTokenizer(Module):

        tokens = self.patch_to_tokens(video)

        tokens, inverse_pack_space = pack_one(tokens, 'b t * d')

        # add the latent

        latents = repeat(self.latent_token, 'd -> b t 1 d', b = tokens.shape[0], t = tokens.shape[1])

        tokens = cat((tokens, latents), dim = -2)

        # pack time

        tokens, inverse_pack_time = pack_one(tokens, 'b * d')

        # encoder

        for attn, ff in self.encoder_layers:
            tokens = attn(tokens) + tokens
            tokens = ff(tokens) + tokens

        tokens = self.encoder_norm(tokens)

        # latent bottleneck

        latents = self.encoded_to_latents(tokens)
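To trace the packing above, a small runnable sketch with made-up sizes, assuming patch_to_tokens emits a (batch, time, height, width, dim) grid (its definition is outside this diff) and reusing pack_one from the earlier hunk:

import torch
from einops import repeat

b, t, h, w, d = 2, 4, 8, 8, 512                               # made-up sizes

tokens = torch.randn(b, t, h, w, d)                           # assumed patch_to_tokens output layout
tokens, inverse_pack_space = pack_one(tokens, 'b t * d')      # (2, 4, 64, 512), space flattened

latent_token = torch.randn(d)                                 # stand-in for self.latent_token
latents = repeat(latent_token, 'd -> b t 1 d', b = b, t = t)  # one latent slot per frame
tokens = torch.cat((tokens, latents), dim = -2)               # (2, 4, 65, 512)

tokens, inverse_pack_time = pack_one(tokens, 'b * d')         # (2, 260, 512), attention sees all frames jointly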
@@ -402,7 +457,27 @@ class VideoTokenizer(Module):

        tokens = self.latents_to_decoder(latents)

-        # from tokens back to video
+        # decoder

        for attn, ff in self.decoder_layers:
            tokens = attn(tokens) + tokens
            tokens = ff(tokens) + tokens

        tokens = self.decoder_norm(tokens)

        # unpack time

        tokens = inverse_pack_time(tokens)

        # excise latents

        tokens = tokens[..., :-1, :]

        # unpack space

        tokens = inverse_pack_space(tokens)

        # project back to patches

        recon_video = self.tokens_to_patch(tokens)
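The decoder side mirrors the packing in reverse; continuing the sketch above (and skipping the latent projections, which return the tokens to width d):

tokens = inverse_pack_time(tokens)        # (2, 260, 512) -> (2, 4, 65, 512)
tokens = tokens[..., :-1, :]              # drop the per-frame latent slot -> (2, 4, 64, 512)
tokens = inverse_pack_space(tokens)       # (2, 4, 64, 512) -> (2, 4, 8, 8, 512)

assert tokens.shape == (b, t, h, w, d)    # back to the patch grid layout expected by tokens_to_patch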
@@ -24,3 +24,6 @@ def test_tokenizer():

    loss = tokenizer(x)
    assert loss.numel() == 1

    latents = tokenizer(x, return_latents = True)
    assert latents.shape[-1] == 32
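For completeness, a hedged sketch of exercising the tokenizer the way the test does; the constructor values and the input shape are assumptions, with dim_latent = 32 chosen to match the assertion above:

import torch

tokenizer = VideoTokenizer(
    dim = 64,
    dim_latent = 32,                              # matches the latents.shape[-1] == 32 assertion
    patch_size = 8
)

x = torch.randn(1, 3, 4, 32, 32)                  # assumed (batch, channels, time, height, width) video

loss = tokenizer(x)                               # scalar reconstruction objective
latents = tokenizer(x, return_latents = True)     # latent bottleneck output, last dim = dim_latent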