make attention masking correct for dynamics model

lucidrains 2025-10-03 11:18:44 -07:00
parent 986bf4c529
commit 77ad96ded2

@@ -317,7 +317,8 @@ class Attention(Module):
         dim_head = 64,
         heads = 8,
         softclamp_value = 50.,
-        pre_rmsnorm = True
+        pre_rmsnorm = True,
+        causal = False
     ):
         super().__init__()
         self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
@@ -331,6 +332,10 @@
         self.to_kv = LinearNoBias(dim, dim_inner * 2)
         self.to_out = LinearNoBias(dim_inner, dim)
 
+        # masking related
+
+        self.causal = causal
+
         # stability related
 
         self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads)
@@ -342,7 +347,8 @@
         self,
         tokens, # (b n d)
         kv_cache = None,
-        return_kv_cache = False
+        return_kv_cache = False,
+        mask = None
     ):
 
         tokens, inverse_packed_batch = pack_one(tokens, '* n d')
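
The forward pass now takes an optional boolean `mask`. Because the next hunk applies it as `sim.masked_fill(~mask, mask_value)`, True marks key positions that may be attended to. A minimal sketch, in plain PyTorch outside this repository, of building such a key-padding mask from per-sequence lengths; the exact shape the module expects is an assumption here (anything broadcastable against the (batch, heads, query, key) attention logits works for the masked_fill):

import torch

# hypothetical valid lengths per sequence; True = key position may be attended to
lengths = torch.tensor([3, 5])
n = 5

key_padding_mask = torch.arange(n) < lengths[:, None]    # (batch, n)
key_padding_mask = key_padding_mask[:, None, None, :]    # broadcastable to (batch, heads, queries, keys)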
@@ -379,6 +385,20 @@
         sim = sim * self.scale
 
+        # masking
+
+        mask_value = -torch.finfo(sim.dtype).max
+
+        if exists(mask):
+            sim = sim.masked_fill(~mask, mask_value)
+
+        if self.causal:
+            i, j = sim.shape[-2:]
+            causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)
+            sim = sim.masked_fill(causal_mask, mask_value)
+
         # attend
 
         attn = sim.softmax(dim = -1)
 
         # aggregate
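
For the causal branch, the `triu(j - i + 1)` offset keeps the mask correct when the key length j exceeds the query length i, as happens when a kv cache is supplied and only the newest tokens are queried. A minimal sketch in plain PyTorch (outside this repository) showing the mask for i = 2 new queries over j = 5 keys:

import torch

i, j = 2, 5  # 2 new query positions, 5 key positions (3 cached + 2 new)

# True entries are masked out, so each query only sees keys up to its own absolute position
causal_mask = torch.ones((i, j), dtype = torch.bool).triu(j - i + 1)

print(causal_mask)
# tensor([[False, False, False, False,  True],
#         [False, False, False, False, False]])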
@@ -691,7 +711,7 @@ class DynamicsModel(Module):
             layers.append(ModuleList([
                 rearrange_to_attend,
                 rearrange_from_attend,
-                Attention(dim = dim, **attn_kwargs),
+                Attention(dim = dim, causal = is_time_block, **attn_kwargs),
                 SwiGLUFeedforward(dim = dim, **ff_kwargs)