allow for the video tokenizer to accept any spatial dimensions by parameterizing the decoder positional embedding with an MLP

This commit is contained in:
lucidrains 2025-10-03 10:08:05 -07:00
parent 90bf19f076
commit 986bf4c529
2 changed files with 35 additions and 5 deletions

View File

@ -440,6 +440,7 @@ class VideoTokenizer(Module):
heads = 8,
),
ff_kwargs: dict = dict(),
decoder_pos_mlp_depth = 2,
channels = 3,
per_image_patch_mask_prob = (0., 0.9), # probability of patch masking appears to be per image probabilities drawn uniformly between 0. and 0.9 - if you are a phd student and think i'm mistakened, please open an issue
):
@ -494,6 +495,15 @@ class VideoTokenizer(Module):
# decoder
# parameterize the decoder positional embeddings for MAE style training so it can be resolution agnostic
self.to_decoder_pos_emb = create_mlp(
dim_in = 2,
dim = dim * 2,
dim_out = dim,
depth = decoder_pos_mlp_depth,
)
decoder_layers = []
for _ in range(decoder_depth):
@ -511,7 +521,8 @@ class VideoTokenizer(Module):
return_latents = False,
mask_patches = None
):
patch_size = self.patch_size
batch, time = video.shape[0], video.shape[2]
patch_size, device = self.patch_size, video.device
*_, height, width = video.shape
@ -521,6 +532,10 @@ class VideoTokenizer(Module):
tokens = self.patch_to_tokens(video)
# get some dimensions
num_patch_height, num_patch_width, _ = tokens.shape[-3:]
# masking
mask_patches = default(mask_patches, self.training)
@ -559,15 +574,29 @@ class VideoTokenizer(Module):
# latent bottleneck
tokens = inverse_pack_time(tokens)
tokens = tokens[..., -1, :]
latents = self.encoded_to_latents(tokens)
if return_latents:
latents = inverse_pack_time(latents)
return latents[..., -1, :]
return latents
tokens = self.latents_to_decoder(latents)
latent_tokens = self.latents_to_decoder(latents)
# decoder
# generate decoder positional embedding and concat the latent token
spatial_pos_height = torch.linspace(-1., 1., num_patch_height, device = device)
spatial_pos_width = torch.linspace(-1., 1., num_patch_width, device = device)
space_height_width_coor = stack(torch.meshgrid(spatial_pos_height, spatial_pos_width, indexing = 'ij'), dim = -1)
decoder_pos_emb = self.to_decoder_pos_emb(space_height_width_coor)
decoder_pos_emb = repeat(decoder_pos_emb, '... -> b t ...', b = batch, t = time)
tokens, _ = pack((decoder_pos_emb, latent_tokens), 'b * d')
# decoder attention
for attn, ff in self.decoder_layers:
tokens = attn(tokens) + tokens

View File

@ -27,6 +27,7 @@ classifiers=[
dependencies = [
"accelerate",
"assoc-scan",
"einx>=0.3.0",
"einops>=0.8.1",
"hl-gauss-pytorch",