allow for the video tokenizer to accept any spatial dimensions by parameterizing the decoder positional embedding with an MLP
commit 986bf4c529 (parent 90bf19f076)
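The change makes the decoder positional embedding a function of normalized (height, width) patch coordinates: a small MLP maps each 2-d coordinate to a dim-sized embedding, so the grid size can vary freely at inference. A minimal standalone sketch of the idea, assuming plain PyTorch (the class name and activation below are illustrative, not the repo's):

import torch
from torch import nn

# illustrative sketch - a tiny MLP that turns normalized (h, w) coordinates
# into positional embeddings, so the spatial grid can change freely
class CoordPosEmbMLP(nn.Module):
    def __init__(self, dim, expansion = 2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2, dim * expansion),
            nn.SiLU(),
            nn.Linear(dim * expansion, dim)
        )

    def forward(self, num_patch_height, num_patch_width):
        h = torch.linspace(-1., 1., num_patch_height)
        w = torch.linspace(-1., 1., num_patch_width)
        coords = torch.stack(torch.meshgrid(h, w, indexing = 'ij'), dim = -1)  # (h, w, 2)
        return self.net(coords)                                                # (h, w, dim)

pos_mlp = CoordPosEmbMLP(dim = 512)
print(pos_mlp(16, 16).shape)   # torch.Size([16, 16, 512])
print(pos_mlp(12, 20).shape)   # torch.Size([12, 20, 512]) - any resolution works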
@@ -440,6 +440,7 @@ class VideoTokenizer(Module):
             heads = 8,
         ),
         ff_kwargs: dict = dict(),
+        decoder_pos_mlp_depth = 2,
         channels = 3,
         per_image_patch_mask_prob = (0., 0.9), # probability of patch masking appears to be per image probabilities drawn uniformly between 0. and 0.9 - if you are a phd student and think i'm mistakened, please open an issue
     ):
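As an aside on the per_image_patch_mask_prob comment above, a hedged sketch of how per-image masking ratios drawn uniformly from that range could translate into a patch mask (the helper name and shapes are assumptions, not the repo's code):

import torch

def sample_patch_mask(batch, num_patches, prob_range = (0., 0.9)):
    # one masking probability per image, drawn uniformly from the range,
    # then an independent Bernoulli draw per patch at that probability
    lo, hi = prob_range
    per_image_prob = torch.empty(batch).uniform_(lo, hi)            # (b,)
    mask = torch.rand(batch, num_patches) < per_image_prob[:, None]  # (b, n) bool
    return mask

mask = sample_patch_mask(batch = 4, num_patches = 256)
print(mask.float().mean(dim = -1))  # roughly matches each image's sampled ratio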
@@ -494,6 +495,15 @@ class VideoTokenizer(Module):

         # decoder

+        # parameterize the decoder positional embeddings for MAE style training so it can be resolution agnostic
+
+        self.to_decoder_pos_emb = create_mlp(
+            dim_in = 2,
+            dim = dim * 2,
+            dim_out = dim,
+            depth = decoder_pos_mlp_depth,
+        )
+
         decoder_layers = []

         for _ in range(decoder_depth):
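create_mlp is a repo helper whose definition is not part of this diff; purely as an assumption, the call above could correspond to something like the sketch below, with depth hidden blocks between an input projection from dim_in and an output projection to dim_out (activation choice and exact wiring are guesses):

from torch import nn

def create_mlp(*, dim_in, dim, dim_out, depth):
    # assumed sketch of the helper: `depth` Linear + activation blocks
    # followed by an output projection to dim_out
    layers = [nn.Linear(dim_in, dim), nn.SiLU()]
    for _ in range(depth - 1):
        layers += [nn.Linear(dim, dim), nn.SiLU()]
    layers.append(nn.Linear(dim, dim_out))
    return nn.Sequential(*layers)

# mirrors the constructor call in this commit: 2-d coordinates in, `dim` out
to_decoder_pos_emb = create_mlp(dim_in = 2, dim = 512 * 2, dim_out = 512, depth = 2)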
@@ -511,7 +521,8 @@ class VideoTokenizer(Module):
         return_latents = False,
         mask_patches = None
     ):
-        patch_size = self.patch_size
+        batch, time = video.shape[0], video.shape[2]
+        patch_size, device = self.patch_size, video.device

         *_, height, width = video.shape
@@ -521,6 +532,10 @@ class VideoTokenizer(Module):

         tokens = self.patch_to_tokens(video)

+        # get some dimensions
+
+        num_patch_height, num_patch_width, _ = tokens.shape[-3:]
+
         # masking

         mask_patches = default(mask_patches, self.training)
@@ -559,15 +574,29 @@ class VideoTokenizer(Module):

         # latent bottleneck

+        tokens = inverse_pack_time(tokens)
+        tokens = tokens[..., -1, :]
+
         latents = self.encoded_to_latents(tokens)

         if return_latents:
-            latents = inverse_pack_time(latents)
-            return latents[..., -1, :]
+            return latents

-        tokens = self.latents_to_decoder(latents)
+        latent_tokens = self.latents_to_decoder(latents)

-        # decoder
+        # generate decoder positional embedding and concat the latent token
+
+        spatial_pos_height = torch.linspace(-1., 1., num_patch_height, device = device)
+        spatial_pos_width = torch.linspace(-1., 1., num_patch_width, device = device)
+
+        space_height_width_coor = stack(torch.meshgrid(spatial_pos_height, spatial_pos_width, indexing = 'ij'), dim = -1)
+
+        decoder_pos_emb = self.to_decoder_pos_emb(space_height_width_coor)
+        decoder_pos_emb = repeat(decoder_pos_emb, '... -> b t ...', b = batch, t = time)
+
+        tokens, _ = pack((decoder_pos_emb, latent_tokens), 'b * d')
+
+        # decoder attention

         for attn, ff in self.decoder_layers:
             tokens = attn(tokens) + tokens
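To make the shapes in the new decoder path concrete, here is a hedged sketch of the repeat and pack step using einops directly; batch, time and dim are made-up example values, and the latent token shape is a simplifying assumption:

import torch
from einops import repeat, pack

batch, time, dim = 2, 4, 512
num_patch_height, num_patch_width = 12, 20

# per-position embeddings from the coordinate MLP: (h, w, d)
decoder_pos_emb = torch.randn(num_patch_height, num_patch_width, dim)

# broadcast the same spatial grid across every batch element and frame
decoder_pos_emb = repeat(decoder_pos_emb, '... -> b t ...', b = batch, t = time)
# -> (batch, time, h, w, dim)

# stand-in for the output of latents_to_decoder, one token per batch element here
latent_tokens = torch.randn(batch, dim)

# pack flattens everything between the batch and feature dims into one token axis,
# concatenating the positional tokens with the latent token(s)
tokens, packed_shapes = pack((decoder_pos_emb, latent_tokens), 'b * d')
print(tokens.shape)  # (batch, time * h * w + 1, dim)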
@@ -27,6 +27,7 @@ classifiers=[

dependencies = [
    "accelerate",
    "assoc-scan",
    "einx>=0.3.0",
    "einops>=0.8.1",
    "hl-gauss-pytorch",