From 51e08526046fa971efdc214ea837120f0e988c8a Mon Sep 17 00:00:00 2001 From: lucidrains Date: Thu, 2 Oct 2025 09:43:30 -0700 Subject: [PATCH] cleanup --- dreamer4/dreamer4.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 4b274f0..d494f77 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -450,9 +450,7 @@ class VideoTokenizer(Module): if mask_patches: min_mask_prob, max_mask_prob = self.per_image_patch_mask_prob - uniform_prob = torch.rand(tokens.shape[:2], device = tokens.device) # (b t) - - mask_prob = uniform_prob * (max_mask_prob - min_mask_prob) + min_mask_prob + mask_prob = torch.empty(tokens.shape[:2], device = tokens.device).uniform_(min_mask_prob, max_mask_prob) # (b t) mask_prob = repeat(mask_prob, 'b t -> b t vh vw', vh = tokens.shape[2], vw = tokens.shape[3]) mask_patch = torch.bernoulli(mask_prob) == 1.