From 77ad96ded2d30e525d4313221f5e2bb8451cc850 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Fri, 3 Oct 2025 11:18:44 -0700
Subject: [PATCH] make attention masking correct for dynamics model

---
 dreamer4/dreamer4.py | 26 +++++++++++++++++++++++---
 1 file changed, 23 insertions(+), 3 deletions(-)

diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index a42861a..8701006 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -317,7 +317,8 @@ class Attention(Module):
         dim_head = 64,
         heads = 8,
         softclamp_value = 50.,
-        pre_rmsnorm = True
+        pre_rmsnorm = True,
+        causal = False
     ):
         super().__init__()
         self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
@@ -331,6 +332,10 @@ class Attention(Module):
         self.to_kv = LinearNoBias(dim, dim_inner * 2)
         self.to_out = LinearNoBias(dim_inner, dim)
 
+        # masking related
+
+        self.causal = causal
+
         # stability related
 
         self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads)
@@ -342,7 +347,8 @@ class Attention(Module):
         self,
         tokens, # (b n d)
         kv_cache = None,
-        return_kv_cache = False
+        return_kv_cache = False,
+        mask = None
     ):
         tokens, inverse_packed_batch = pack_one(tokens, '* n d')
 
@@ -379,6 +385,20 @@ class Attention(Module):
 
         sim = sim * self.scale
 
+        # masking
+
+        mask_value = -torch.finfo(sim.dtype).max
+
+        if exists(mask):
+            sim = sim.masked_fill(~mask, mask_value)
+
+        if self.causal:
+            i, j = sim.shape[-2:]
+            causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)
+            sim = sim.masked_fill(causal_mask, mask_value)
+
+        # attend
+
         attn = sim.softmax(dim = -1)
 
         # aggregate
@@ -691,7 +711,7 @@ class DynamicsModel(Module):
             layers.append(ModuleList([
                 rearrange_to_attend,
                 rearrange_from_attend,
-                Attention(dim = dim, **attn_kwargs),
+                Attention(dim = dim, causal = is_time_block, **attn_kwargs),
                 SwiGLUFeedforward(dim = dim, **ff_kwargs)
             ]))
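
Editor's note on the causal mask construction: the `triu(j - i + 1)` offset handles the kv cache case, where the i queries are the trailing suffix of a j-length key sequence (j >= i). A minimal standalone sketch (not part of the patch) checking that the offset mask equals the corresponding rows of a full (j, j) causal mask:

    # verify triu(j - i + 1) masks exactly the future keys when queries
    # are the last i positions of j total keys (kv-cache decoding)

    import torch

    i, j = 3, 7  # hypothetical: 3 new queries over 4 cached + 3 new keys

    offset_mask = torch.ones((i, j), dtype = torch.bool).triu(j - i + 1)

    # reference: full causal mask over all j positions, last i query rows

    full_mask = torch.ones((j, j), dtype = torch.bool).triu(1)

    assert torch.equal(offset_mask, full_mask[-i:])

When i == j (no cache), the offset reduces to the familiar triu(1), so the same expression covers both training and cached decoding.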
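
The order of masking in the forward pass also matters: the optional boolean `mask` (True = may attend) is applied first, then the causal mask on top, both filling with the dtype's most negative finite value so softmax drives those positions to zero. A standalone sketch of that pattern, assuming (as the einops comments in the file suggest) attention logits of shape (batch, heads, i, j):

    import torch

    b, h, i, j = 2, 8, 4, 4
    sim = torch.randn(b, h, i, j)  # stand-in for the scaled q @ k logits
    mask_value = -torch.finfo(sim.dtype).max

    # explicit mask first, e.g. a key padding mask broadcast over heads/queries

    mask = torch.ones(b, 1, 1, j, dtype = torch.bool)
    sim = sim.masked_fill(~mask, mask_value)

    # causal mask on top, same offset construction as in the patch

    causal_mask = torch.ones((i, j), dtype = torch.bool).triu(j - i + 1)
    sim = sim.masked_fill(causal_mask, mask_value)

    attn = sim.softmax(dim = -1)

Note that only the time-axis blocks of DynamicsModel set causal = True (via is_time_block); spatial attention within a frame stays bidirectional, which is why `causal` is a per-instance flag rather than a global default.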