take care of the MAE portion from Kaiming He
commit e6c808960f
parent 49082d8629
@@ -362,7 +362,8 @@ class VideoTokenizer(Module):
             heads = 8,
         ),
         ff_kwargs: dict = dict(),
-        channels = 3
+        channels = 3,
+        per_image_patch_mask_prob = (0., 0.9), # the patch masking probability appears to be drawn uniformly per image between 0. and 0.9 - if you are a phd student and think i'm mistaken, please open an issue
     ):
         super().__init__()
 
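For reference, a minimal sketch of what the new per_image_patch_mask_prob = (0., 0.9) default means: each image (frame) in the batch gets its own masking probability drawn uniformly from that range, rather than one fixed masking ratio shared by the whole batch. Shapes and variable names below are illustrative only.

import torch

min_prob, max_prob = 0., 0.9    # the new default range
batch, frames = 2, 4            # hypothetical batch of 2 videos with 4 frames each

# one masking probability per (batch, frame) pair, drawn uniformly from [min_prob, max_prob)
mask_prob = torch.rand(batch, frames) * (max_prob - min_prob) + min_prob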
@@ -371,6 +372,10 @@ class VideoTokenizer(Module):
         # special tokens
 
         self.latent_token = Parameter(torch.randn(dim) * 1e-2)
+
+        # mae masking - Kaiming He paper from long ago
+
+        self.per_image_patch_mask_prob = per_image_patch_mask_prob
         self.mask_token = Parameter(torch.randn(dim) * 1e-2)
 
         # patch and unpatch
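The learned mask token used here is a single d-dimensional vector, initialized near zero, that stands in for every masked patch token, following He et al.'s MAE. A rough sketch of the substitution, assuming a hypothetical dim of 512 and plain torch.where in place of einx:

import torch
from torch.nn import Parameter

dim = 512                                   # hypothetical model dimension
mask_token = Parameter(torch.randn(dim) * 1e-2)

tokens = torch.randn(2, 4, 8, 8, dim)       # (b t vh vw d) patch tokens
masked = torch.rand(2, 4, 8, 8) < 0.5       # which patches to mask

# broadcast the single learned token over every masked position
tokens = torch.where(masked[..., None], mask_token, tokens)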
@@ -425,7 +430,8 @@ class VideoTokenizer(Module):
     def forward(
         self,
         video, # (b c t h w)
-        return_latents = False
+        return_latents = False,
+        mask_patches = None
     ):
         patch_size = self.patch_size
 
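The new mask_patches argument defaults to None, which the forward pass resolves to self.training via the repo's default helper (assumed here to follow the usual convention), so masking is on during training and off at inference unless overridden explicitly:

def default(val, d):
    return val if val is not None else d

# mask_patches = None  -> masking follows the module's train/eval mode
# mask_patches = False -> masking disabled even while training
# mask_patches = True  -> masking forced on, e.g. at evaluation time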
@@ -437,6 +443,24 @@ class VideoTokenizer(Module):
 
         tokens = self.patch_to_tokens(video)
 
+        # masking
+
+        mask_patches = default(mask_patches, self.training)
+
+        if mask_patches:
+            min_mask_prob, max_mask_prob = self.per_image_patch_mask_prob
+
+            uniform_prob = torch.rand(tokens.shape[:2], device = tokens.device) # (b t)
+
+            mask_prob = uniform_prob * (max_mask_prob - min_mask_prob) + min_mask_prob
+
+            mask_prob = repeat(mask_prob, 'b t -> b t vh vw', vh = tokens.shape[2], vw = tokens.shape[3])
+            mask_patch = torch.bernoulli(mask_prob) == 1.
+
+            tokens = einx.where('b t vh vw, d, b t vh vw d', mask_patch, self.mask_token, tokens)
+
+        # pack space
+
         tokens, inverse_pack_space = pack_one(tokens, 'b t * d')
 
         # add the latent
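Putting the new masking block together, here is a self-contained sketch in plain torch (no einx / einops), with hypothetical shapes, of how a per-image probability becomes a bernoulli patch mask and how masked patches are swapped for the learned mask token:

import torch

b, t, vh, vw, d = 2, 4, 8, 8, 512           # hypothetical video token shape
tokens = torch.randn(b, t, vh, vw, d)
mask_token = torch.randn(d) * 1e-2

min_p, max_p = 0., 0.9

# one masking probability per image (b t), broadcast over the patch grid (vh vw)
mask_prob = torch.rand(b, t) * (max_p - min_p) + min_p
mask_prob = mask_prob[:, :, None, None].expand(b, t, vh, vw)

# sample which patches to drop, then swap in the learned mask token
mask_patch = torch.bernoulli(mask_prob) == 1.
tokens = torch.where(mask_patch[..., None], mask_token, tokens)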