This commit is contained in:
lucidrains 2025-10-02 09:14:39 -07:00
parent e6c808960f
commit 0b503d880d

View File

@ -457,7 +457,7 @@ class VideoTokenizer(Module):
mask_prob = repeat(mask_prob, 'b t -> b t vh vw', vh = tokens.shape[2], vw = tokens.shape[3])
mask_patch = torch.bernoulli(mask_prob) == 1.
tokens = einx.where('b t vh vw, d, b t vh vw d', mask_patch, self.mask_token, tokens)
tokens = einx.where('..., d, ... d', mask_patch, self.mask_token, tokens)
# pack space