flesh out tokenizer even more

lucidrains 2025-10-02 06:11:04 -07:00
parent 31c4aa28c7
commit 0285bba821
2 changed files with 80 additions and 2 deletions


@@ -24,7 +24,7 @@ from accelerate import Accelerator
 # vh, vw - video height and width

 import einx
-from einops import einsum, rearrange, repeat, reduce
+from einops import einsum, rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange

 # flex attention - but will make sure it works if it is not available
@@ -51,9 +51,21 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d

+def first(arr):
+    return arr[0]
+
 def divisible_by(num, den):
     return (num % den) == 0

+def pack_one(t, pattern):
+    packed, packed_shape = pack([t], pattern)
+
+    def inverse(out, inv_pattern = None):
+        inv_pattern = default(inv_pattern, pattern)
+        return first(unpack(out, packed_shape, inv_pattern))
+
+    return packed, inverse
+
 def l2norm(t):
     return F.normalize(t, dim = -1, p = 2)
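For reference, a minimal sketch (not part of the diff) of how the new pack_one helper behaves, built directly on einops' pack / unpack; the tensor sizes here are made up for illustration:

import torch
from einops import pack, unpack

def pack_one(t, pattern):
    # pack a single tensor and return a closure that undoes the packing
    packed, packed_shape = pack([t], pattern)

    def inverse(out, inv_pattern = None):
        return unpack(out, packed_shape, inv_pattern or pattern)[0]

    return packed, inverse

tokens = torch.randn(2, 4, 8, 8, 512)          # (batch, time, height, width, dim)
flat, inverse = pack_one(tokens, 'b t * d')     # height and width collapsed into one axis
assert flat.shape == (2, 4, 64, 512)
assert inverse(flat).shape == tokens.shape      # the closure restores the packed axes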
@@ -337,12 +349,21 @@ class VideoTokenizer(Module):
         dim,
         dim_latent,
         patch_size,
+        encoder_depth = 4,
+        decoder_depth = 4,
+        attn_kwargs: dict = dict(),
+        ff_kwargs: dict = dict(),
         channels = 3
     ):
         super().__init__()
         self.patch_size = patch_size

+        # special tokens
+
+        self.latent_token = Parameter(torch.randn(dim) * 1e-2)
+        self.mask_token = Parameter(torch.randn(dim) * 1e-2)
+
         # patch and unpatch

         dim_patch = channels * patch_size ** 2
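Here dim_patch is the flattened pixel count per patch. A hedged sketch (not shown in this hunk) of what the patching rearrange presumably looks like, assuming a (batch, channels, time, height, width) video layout; the real patch_to_tokens likely also projects dim_patch to dim:

import torch
from einops.layers.torch import Rearrange

patch_size, channels = 8, 3
dim_patch = channels * patch_size ** 2          # 192

# hypothetical patching step, for illustration only
to_patches = Rearrange('b c t (h p1) (w p2) -> b t h w (c p1 p2)', p1 = patch_size, p2 = patch_size)

video = torch.randn(1, channels, 4, 32, 32)     # assumed (batch, channels, time, height, width)
patches = to_patches(video)
assert patches.shape == (1, 4, 4, 4, dim_patch) # 32 / 8 = 4 patches per spatial side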
@@ -361,7 +382,14 @@ class VideoTokenizer(Module):
         encoder_layers = []

+        for _ in range(encoder_depth):
+            encoder_layers.append(ModuleList([
+                Attention(dim = dim, **attn_kwargs),
+                SwiGLUFeedforward(dim = dim, **ff_kwargs)
+            ]))
+
         self.encoder_layers = ModuleList(encoder_layers)
+        self.encoder_norm = RMSNorm(dim)

         # latents
@@ -376,7 +404,14 @@ class VideoTokenizer(Module):
         decoder_layers = []

+        for _ in range(decoder_depth):
+            decoder_layers.append(ModuleList([
+                Attention(dim = dim, **attn_kwargs),
+                SwiGLUFeedforward(dim = dim, **ff_kwargs)
+            ]))
+
         self.decoder_layers = ModuleList(decoder_layers)
+        self.decoder_norm = RMSNorm(dim)

     def forward(
         self,
@@ -393,6 +428,26 @@ class VideoTokenizer(Module):
         tokens = self.patch_to_tokens(video)

+        tokens, inverse_pack_space = pack_one(tokens, 'b t * d')
+
+        # add the latent
+
+        latents = repeat(self.latent_token, 'd -> b t 1 d', b = tokens.shape[0], t = tokens.shape[1])
+
+        tokens = cat((tokens, latents), dim = -2)
+
+        # pack time
+
+        tokens, inverse_pack_time = pack_one(tokens, 'b * d')
+
+        # encoder
+
+        for attn, ff in self.encoder_layers:
+            tokens = attn(tokens) + tokens
+            tokens = ff(tokens) + tokens
+
+        tokens = self.encoder_norm(tokens)
+
         # latent bottleneck

         latents = self.encoded_to_latents(tokens)
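In other words, each frame gets one learned latent token appended after its spatial tokens before the time axis is packed into one sequence. A toy sketch of that bookkeeping (not part of the diff), with made-up sizes:

import torch
from torch import cat
from einops import repeat

dim = 512
latent_token = torch.randn(dim) * 1e-2          # an nn.Parameter in the real module

tokens = torch.randn(2, 4, 64, dim)             # (batch, time, space, dim) after packing space
latents = repeat(latent_token, 'd -> b t 1 d', b = tokens.shape[0], t = tokens.shape[1])

tokens = cat((tokens, latents), dim = -2)       # one extra token per frame, at the last position
assert tokens.shape == (2, 4, 65, dim)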
@@ -402,7 +457,27 @@ class VideoTokenizer(Module):
         tokens = self.latents_to_decoder(latents)

-        # from tokens back to video
+        # decoder
+
+        for attn, ff in self.decoder_layers:
+            tokens = attn(tokens) + tokens
+            tokens = ff(tokens) + tokens
+
+        tokens = self.decoder_norm(tokens)
+
+        # unpack time
+
+        tokens = inverse_pack_time(tokens)
+
+        # excise latents
+
+        tokens = tokens[..., :-1, :]
+
+        # unpack space
+
+        tokens = inverse_pack_space(tokens)
+
+        # project back to patches

         recon_video = self.tokens_to_patch(tokens)
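The decode side undoes the packing in reverse order: time is unpacked first, the per-frame latent (the last token of each frame) is sliced off, then space is unpacked. A self-contained toy roundtrip (not part of the diff), using einops pack / unpack directly, showing the shapes line up:

import torch
from einops import pack, unpack

tokens = torch.randn(2, 4, 8, 8, 512)                       # (batch, time, height, width, dim)

flat, space_shape = pack([tokens], 'b t * d')                # (2, 4, 64, 512)
with_latent = torch.cat((flat, torch.zeros(2, 4, 1, 512)), dim = -2)
seq, time_shape = pack([with_latent], 'b * d')               # (2, 260, 512)

# ... encoder / latent bottleneck / decoder would act on `seq` here ...

out = unpack(seq, time_shape, 'b * d')[0]                    # back to (2, 4, 65, 512)
out = out[..., :-1, :]                                       # excise the per-frame latent
out = unpack(out, space_shape, 'b t * d')[0]                 # back to (2, 4, 8, 8, 512)
assert out.shape == tokens.shape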


@@ -24,3 +24,6 @@ def test_tokenizer():
     loss = tokenizer(x)
     assert loss.numel() == 1
+
+    latents = tokenizer(x, return_latents = True)
+    assert latents.shape[-1] == 32
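The extended test exercises both paths: the default call returns a scalar reconstruction loss, and return_latents = True returns the bottleneck latents. A hedged usage sketch, assuming the test builds the tokenizer with dim_latent = 32 (inferred from the shape assert); the other hyperparameters and the (batch, channels, time, height, width) input layout are illustrative guesses:

import torch
# assuming VideoTokenizer has been imported from the package (import path omitted here)

tokenizer = VideoTokenizer(
    dim = 512,
    dim_latent = 32,            # matches the new shape assert in the test
    patch_size = 8,
    encoder_depth = 4,
    decoder_depth = 4
)

video = torch.randn(1, 3, 4, 32, 32)    # assumed (batch, channels, time, height, width)

loss = tokenizer(video)                  # default: scalar reconstruction loss
latents = tokenizer(video, return_latents = True)
assert latents.shape[-1] == 32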