diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 3998a6b..4b274f0 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -457,7 +457,7 @@ class VideoTokenizer(Module): mask_prob = repeat(mask_prob, 'b t -> b t vh vw', vh = tokens.shape[2], vw = tokens.shape[3]) mask_patch = torch.bernoulli(mask_prob) == 1. - tokens = einx.where('b t vh vw, d, b t vh vw d', mask_patch, self.mask_token, tokens) + tokens = einx.where('..., d, ... d', mask_patch, self.mask_token, tokens) # pack space