flesh out tokenizer even more

This commit is contained in:
lucidrains 2025-10-02 06:11:04 -07:00
parent 31c4aa28c7
commit 0285bba821
2 changed files with 80 additions and 2 deletions

View File

@ -24,7 +24,7 @@ from accelerate import Accelerator
# vh, vw - video height and width
import einx
from einops import einsum, rearrange, repeat, reduce
from einops import einsum, rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange
# flex attention - but will make sure it works if it is not available
@ -51,9 +51,21 @@ def exists(v):
def default(v, d):
    """Return `v` when it exists (per `exists`), otherwise fall back to `d`."""
    if exists(v):
        return v
    return d
def first(arr):
    """Return the first element of an indexable sequence.

    Raises the sequence's own error (e.g. IndexError) when empty.
    """
    head = arr[0]
    return head
def divisible_by(num, den):
    """True when `num` divides evenly by `den` (zero remainder)."""
    remainder = num % den
    return not remainder
def pack_one(t, pattern):
    # pack a single tensor along the given einops pattern, returning
    # the packed tensor together with a closure that undoes the packing
    packed, packed_shape = pack([t], pattern)

    def inverse(out, inv_pattern = None):
        # allow unpacking under a different pattern than the one used to pack
        unpack_pattern = default(inv_pattern, pattern)
        return unpack(out, packed_shape, unpack_pattern)[0]

    return packed, inverse
def l2norm(t):
    """Unit-normalize `t` along its last dimension using the L2 norm."""
    return F.normalize(t, p = 2, dim = -1)
@ -337,12 +349,21 @@ class VideoTokenizer(Module):
dim,
dim_latent,
patch_size,
encoder_depth = 4,
decoder_depth = 4,
attn_kwargs: dict = dict(),
ff_kwargs: dict = dict(),
channels = 3
):
super().__init__()
self.patch_size = patch_size
# special tokens
self.latent_token = Parameter(torch.randn(dim) * 1e-2)
self.mask_token = Parameter(torch.randn(dim) * 1e-2)
# patch and unpatch
dim_patch = channels * patch_size ** 2
@ -361,7 +382,14 @@ class VideoTokenizer(Module):
encoder_layers = []
for _ in range(encoder_depth):
encoder_layers.append(ModuleList([
Attention(dim = dim, **attn_kwargs),
SwiGLUFeedforward(dim = dim, **ff_kwargs)
]))
self.encoder_layers = ModuleList(encoder_layers)
self.encoder_norm = RMSNorm(dim)
# latents
@ -376,7 +404,14 @@ class VideoTokenizer(Module):
decoder_layers = []
for _ in range(decoder_depth):
decoder_layers.append(ModuleList([
Attention(dim = dim, **attn_kwargs),
SwiGLUFeedforward(dim = dim, **ff_kwargs)
]))
self.decoder_layers = ModuleList(decoder_layers)
self.decoder_norm = RMSNorm(dim)
def forward(
self,
@ -393,6 +428,26 @@ class VideoTokenizer(Module):
tokens = self.patch_to_tokens(video)
tokens, inverse_pack_space = pack_one(tokens, 'b t * d')
# add the latent
latents = repeat(self.latent_token, 'd -> b t 1 d', b = tokens.shape[0], t = tokens.shape[1])
tokens = cat((tokens, latents), dim = -2)
# pack time
tokens, inverse_pack_time = pack_one(tokens, 'b * d')
# encoder
for attn, ff in self.encoder_layers:
tokens = attn(tokens) + tokens
tokens = ff(tokens) + tokens
tokens = self.encoder_norm(tokens)
# latent bottleneck
latents = self.encoded_to_latents(tokens)
@ -402,7 +457,27 @@ class VideoTokenizer(Module):
tokens = self.latents_to_decoder(latents)
# from tokens back to video
# decoder
for attn, ff in self.decoder_layers:
tokens = attn(tokens) + tokens
tokens = ff(tokens) + tokens
tokens = self.decoder_norm(tokens)
# unpack time
tokens = inverse_pack_time(tokens)
# excise latents
tokens = tokens[..., :-1, :]
# unpack space
tokens = inverse_pack_space(tokens)
# project back to patches
recon_video = self.tokens_to_patch(tokens)

View File

@ -24,3 +24,6 @@ def test_tokenizer():
loss = tokenizer(x)
assert loss.numel() == 1
latents = tokenizer(x, return_latents = True)
assert latents.shape[-1] == 32