allow for the video tokenizer to accept any spatial dimensions by parameterizing the decoder positional embedding with an MLP
parent 90bf19f076 · commit 986bf4c529
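The change in a nutshell: instead of a learned positional-embedding table tied to one training grid size, the decoder's positional embedding is produced by a small MLP from normalized patch-grid coordinates, so any height and width can be queried at decode time. A minimal sketch of the idea, assuming nothing from the repository beyond what this diff shows (all names below are made up for illustration):

import torch
from torch import nn

dim = 512

# a small MLP from a normalized (y, x) coordinate pair to an embedding
pos_mlp = nn.Sequential(nn.Linear(2, dim * 2), nn.SiLU(), nn.Linear(dim * 2, dim))

def pos_emb_for_grid(h, w):
    coords = torch.stack(torch.meshgrid(
        torch.linspace(-1., 1., h),
        torch.linspace(-1., 1., w),
        indexing = 'ij'
    ), dim = -1)            # (h, w, 2) coordinates in [-1, 1]
    return pos_mlp(coords)  # (h, w, dim) positional embeddings

print(pos_emb_for_grid(8, 8).shape)    # torch.Size([8, 8, 512])
print(pos_emb_for_grid(14, 24).shape)  # torch.Size([14, 24, 512]) - any grid size works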
@@ -440,6 +440,7 @@ class VideoTokenizer(Module):
             heads = 8,
         ),
         ff_kwargs: dict = dict(),
+        decoder_pos_mlp_depth = 2,
         channels = 3,
         per_image_patch_mask_prob = (0., 0.9), # probability of patch masking appears to be per image probabilities drawn uniformly between 0. and 0.9 - if you are a phd student and think i'm mistaken, please open an issue
     ):
@@ -494,6 +495,15 @@ class VideoTokenizer(Module):
 
         # decoder
 
+        # parameterize the decoder positional embeddings for MAE style training so it can be resolution agnostic
+
+        self.to_decoder_pos_emb = create_mlp(
+            dim_in = 2,
+            dim = dim * 2,
+            dim_out = dim,
+            depth = decoder_pos_mlp_depth,
+        )
+
         decoder_layers = []
 
         for _ in range(decoder_depth):
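Note that create_mlp is a repository helper this diff does not show. Going only by the call site above, a plausible, purely hypothetical shape for it (the real helper may well differ):

from torch import nn

def create_mlp(dim_in, dim, dim_out, depth):
    # depth hidden transitions into a width-`dim` space, then a final projection
    layers, curr_dim = [], dim_in
    for _ in range(depth):
        layers += [nn.Linear(curr_dim, dim), nn.SiLU()]
        curr_dim = dim
    layers.append(nn.Linear(curr_dim, dim_out))
    return nn.Sequential(*layers)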
@@ -511,7 +521,8 @@ class VideoTokenizer(Module):
         return_latents = False,
         mask_patches = None
     ):
-        patch_size = self.patch_size
+        batch, time = video.shape[0], video.shape[2]
+        patch_size, device = self.patch_size, video.device
 
         *_, height, width = video.shape
 
@@ -521,6 +532,10 @@ class VideoTokenizer(Module):
 
         tokens = self.patch_to_tokens(video)
 
+        # get some dimensions
+
+        num_patch_height, num_patch_width, _ = tokens.shape[-3:]
+
         # masking
 
         mask_patches = default(mask_patches, self.training)
@@ -559,15 +574,29 @@ class VideoTokenizer(Module):
 
         # latent bottleneck
 
+        tokens = inverse_pack_time(tokens)
+        tokens = tokens[..., -1, :]
+
         latents = self.encoded_to_latents(tokens)
 
         if return_latents:
-            latents = inverse_pack_time(latents)
-            return latents[..., -1, :]
+            return latents
 
-        tokens = self.latents_to_decoder(latents)
+        latent_tokens = self.latents_to_decoder(latents)
 
-        # decoder
+        # generate decoder positional embedding and concat the latent token
 
+        spatial_pos_height = torch.linspace(-1., 1., num_patch_height, device = device)
+        spatial_pos_width = torch.linspace(-1., 1., num_patch_width, device = device)
+
+        space_height_width_coor = stack(torch.meshgrid(spatial_pos_height, spatial_pos_width, indexing = 'ij'), dim = -1)
+
+        decoder_pos_emb = self.to_decoder_pos_emb(space_height_width_coor)
+        decoder_pos_emb = repeat(decoder_pos_emb, '... -> b t ...', b = batch, t = time)
+
+        tokens, _ = pack((decoder_pos_emb, latent_tokens), 'b * d')
+
+        # decoder attention
+
         for attn, ff in self.decoder_layers:
             tokens = attn(tokens) + tokens
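The pack call above is what joins the freshly generated positional-embedding tokens and the latent tokens into one attention sequence; the decoder then attends over both, so reconstruction works at whatever grid the positional MLP was queried with. A hedged illustration of einops' pack semantics, with shapes made up purely for this example (the real tensors in the repo may differ):

import torch
from einops import pack

decoder_pos_emb = torch.randn(2, 4, 8, 8, 512)  # (batch, time, patch height, patch width, dim)
latent_tokens   = torch.randn(2, 1, 512)        # (batch, num latent tokens, dim)

# 'b * d' keeps batch and feature dims fixed and flattens everything in
# between, concatenating both tensors along the resulting sequence axis
tokens, packed_shapes = pack((decoder_pos_emb, latent_tokens), 'b * d')
print(tokens.shape)  # torch.Size([2, 257, 512]) -> 4 * 8 * 8 + 1 positions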
@@ -27,6 +27,7 @@ classifiers=[
 
 dependencies = [
     "accelerate",
+    "assoc-scan",
     "einx>=0.3.0",
     "einops>=0.8.1",
     "hl-gauss-pytorch",