From e6c808960fe69a75fc8c55ed8b105c59638becbd Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Thu, 2 Oct 2025 08:57:44 -0700
Subject: [PATCH] take care of the MAE portion from Kaiming He

---
 dreamer4/dreamer4.py | 28 ++++++++++++++++++++++++++--
 1 file changed, 26 insertions(+), 2 deletions(-)

diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 627b33c..3998a6b 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -362,7 +362,8 @@ class VideoTokenizer(Module):
             heads = 8,
         ),
         ff_kwargs: dict = dict(),
-        channels = 3
+        channels = 3,
+        per_image_patch_mask_prob = (0., 0.9), # probability of patch masking appears to be a per-image probability drawn uniformly between 0. and 0.9 - if you are a phd student and think i'm mistaken, please open an issue
     ):
         super().__init__()
 
@@ -371,6 +372,10 @@ class VideoTokenizer(Module):
         # special tokens
 
         self.latent_token = Parameter(torch.randn(dim) * 1e-2)
+
+        # mae masking - Kaiming He paper from long ago
+
+        self.per_image_patch_mask_prob = per_image_patch_mask_prob
         self.mask_token = Parameter(torch.randn(dim) * 1e-2)
 
         # patch and unpatch
@@ -425,7 +430,8 @@ class VideoTokenizer(Module):
     def forward(
         self,
         video, # (b c t h w)
-        return_latents = False
+        return_latents = False,
+        mask_patches = None
     ):
         patch_size = self.patch_size
 
@@ -437,6 +443,24 @@ class VideoTokenizer(Module):
 
         tokens = self.patch_to_tokens(video)
 
+        # masking
+
+        mask_patches = default(mask_patches, self.training)
+
+        if mask_patches:
+            min_mask_prob, max_mask_prob = self.per_image_patch_mask_prob
+
+            uniform_prob = torch.rand(tokens.shape[:2], device = tokens.device) # (b t)
+
+            mask_prob = uniform_prob * (max_mask_prob - min_mask_prob) + min_mask_prob
+
+            mask_prob = repeat(mask_prob, 'b t -> b t vh vw', vh = tokens.shape[2], vw = tokens.shape[3])
+            mask_patch = torch.bernoulli(mask_prob) == 1.
+
+            tokens = einx.where('b t vh vw, d, b t vh vw d', mask_patch, self.mask_token, tokens)
+
+        # pack space
+
         tokens, inverse_pack_space = pack_one(tokens, 'b t * d')
 
         # add the latent
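
For readers skimming the patch, here is a minimal standalone sketch of the masking scheme it introduces, written in plain PyTorch without the repo's einx/einops helpers: each image (each frame of each batch element) gets its own mask probability drawn uniformly from a configured range such as (0., 0.9), a Bernoulli mask is sampled over that image's patch grid at that probability, and masked patch embeddings are replaced by a learned mask token, in the spirit of Kaiming He's masked autoencoder. The PatchMasker class, its argument names, and the (b, t, vh, vw, d) token layout are illustrative assumptions for this sketch, not part of dreamer4's API.

import torch
from torch import nn

class PatchMasker(nn.Module):
    # hypothetical helper (not in dreamer4) sketching per-image MAE-style patch masking

    def __init__(self, dim, mask_prob_range = (0., 0.9)):
        super().__init__()
        self.mask_prob_range = mask_prob_range
        self.mask_token = nn.Parameter(torch.randn(dim) * 1e-2)  # learned replacement embedding

    def forward(self, tokens):
        # tokens: (batch, time, patch-grid-height, patch-grid-width, dim)
        b, t, vh, vw, d = tokens.shape
        lo, hi = self.mask_prob_range

        # one masking probability per image, drawn uniformly from [lo, hi)
        mask_prob = torch.rand(b, t, device = tokens.device) * (hi - lo) + lo

        # boolean mask over each image's patch grid; rand < p is a Bernoulli(p) draw per patch
        mask = torch.rand(b, t, vh, vw, device = tokens.device) < mask_prob[:, :, None, None]

        # swap masked patch embeddings for the mask token (broadcast over the feature dim)
        tokens = torch.where(mask[..., None], self.mask_token, tokens)
        return tokens, mask

# usage sketch
masker = PatchMasker(dim = 512)
patch_tokens = torch.randn(2, 4, 8, 8, 512)   # (b t vh vw d)
masked_tokens, mask = masker(patch_tokens)    # mask: (b t vh vw) booleans

Drawing the ratio per image, rather than fixing a single global ratio as in the original MAE (which masked a fixed 75% of patches), means one batch spans everything from nearly-unmasked to heavily-masked frames, which is what the (0., 0.9) default range in the patch encodes.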