take care of the MAE portion from Kaiming He

This commit is contained in:
lucidrains 2025-10-02 08:57:44 -07:00
parent 49082d8629
commit e6c808960f

View File

@ -362,7 +362,8 @@ class VideoTokenizer(Module):
heads = 8,
),
ff_kwargs: dict = dict(),
channels = 3
channels = 3,
per_image_patch_mask_prob = (0., 0.9), # probability of patch masking appears to be per image probabilities drawn uniformly between 0. and 0.9 - if you are a phd student and think i'm mistakened, please open an issue
):
super().__init__()
@ -371,6 +372,10 @@ class VideoTokenizer(Module):
# special tokens
self.latent_token = Parameter(torch.randn(dim) * 1e-2)
# mae masking - Kaiming He paper from long ago
self.per_image_patch_mask_prob = per_image_patch_mask_prob
self.mask_token = Parameter(torch.randn(dim) * 1e-2)
# patch and unpatch
@ -425,7 +430,8 @@ class VideoTokenizer(Module):
def forward(
self,
video, # (b c t h w)
return_latents = False
return_latents = False,
mask_patches = None
):
patch_size = self.patch_size
@ -437,6 +443,24 @@ class VideoTokenizer(Module):
tokens = self.patch_to_tokens(video)
# masking
mask_patches = default(mask_patches, self.training)
if mask_patches:
min_mask_prob, max_mask_prob = self.per_image_patch_mask_prob
uniform_prob = torch.rand(tokens.shape[:2], device = tokens.device) # (b t)
mask_prob = uniform_prob * (max_mask_prob - min_mask_prob) + min_mask_prob
mask_prob = repeat(mask_prob, 'b t -> b t vh vw', vh = tokens.shape[2], vw = tokens.shape[3])
mask_patch = torch.bernoulli(mask_prob) == 1.
tokens = einx.where('b t vh vw, d, b t vh vw d', mask_patch, self.mask_token, tokens)
# pack space
tokens, inverse_pack_space = pack_one(tokens, 'b t * d')
# add the latent