take care of the MAE portion from Kaiming He
This commit is contained in:
parent
49082d8629
commit
e6c808960f
@ -362,7 +362,8 @@ class VideoTokenizer(Module):
|
||||
heads = 8,
|
||||
),
|
||||
ff_kwargs: dict = dict(),
|
||||
channels = 3
|
||||
channels = 3,
|
||||
per_image_patch_mask_prob = (0., 0.9), # probability of patch masking appears to be per image probabilities drawn uniformly between 0. and 0.9 - if you are a phd student and think i'm mistakened, please open an issue
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -371,6 +372,10 @@ class VideoTokenizer(Module):
|
||||
# special tokens
|
||||
|
||||
self.latent_token = Parameter(torch.randn(dim) * 1e-2)
|
||||
|
||||
# mae masking - Kaiming He paper from long ago
|
||||
|
||||
self.per_image_patch_mask_prob = per_image_patch_mask_prob
|
||||
self.mask_token = Parameter(torch.randn(dim) * 1e-2)
|
||||
|
||||
# patch and unpatch
|
||||
@ -425,7 +430,8 @@ class VideoTokenizer(Module):
|
||||
def forward(
|
||||
self,
|
||||
video, # (b c t h w)
|
||||
return_latents = False
|
||||
return_latents = False,
|
||||
mask_patches = None
|
||||
):
|
||||
patch_size = self.patch_size
|
||||
|
||||
@ -437,6 +443,24 @@ class VideoTokenizer(Module):
|
||||
|
||||
tokens = self.patch_to_tokens(video)
|
||||
|
||||
# masking
|
||||
|
||||
mask_patches = default(mask_patches, self.training)
|
||||
|
||||
if mask_patches:
|
||||
min_mask_prob, max_mask_prob = self.per_image_patch_mask_prob
|
||||
|
||||
uniform_prob = torch.rand(tokens.shape[:2], device = tokens.device) # (b t)
|
||||
|
||||
mask_prob = uniform_prob * (max_mask_prob - min_mask_prob) + min_mask_prob
|
||||
|
||||
mask_prob = repeat(mask_prob, 'b t -> b t vh vw', vh = tokens.shape[2], vw = tokens.shape[3])
|
||||
mask_patch = torch.bernoulli(mask_prob) == 1.
|
||||
|
||||
tokens = einx.where('b t vh vw, d, b t vh vw d', mask_patch, self.mask_token, tokens)
|
||||
|
||||
# pack space
|
||||
|
||||
tokens, inverse_pack_space = pack_one(tokens, 'b t * d')
|
||||
|
||||
# add the latent
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user