make attention masking correct for dynamics model
parent 986bf4c529
commit 77ad96ded2
@@ -317,7 +317,8 @@ class Attention(Module):
         dim_head = 64,
         heads = 8,
         softclamp_value = 50.,
-        pre_rmsnorm = True
+        pre_rmsnorm = True,
+        causal = False
     ):
         super().__init__()
         self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
@@ -331,6 +332,10 @@ class Attention(Module):
         self.to_kv = LinearNoBias(dim, dim_inner * 2)
         self.to_out = LinearNoBias(dim_inner, dim)

+        # masking related
+
+        self.causal = causal
+
         # stability related

         self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads)
@@ -342,7 +347,8 @@ class Attention(Module):
         self,
         tokens, # (b n d)
         kv_cache = None,
-        return_kv_cache = False
+        return_kv_cache = False,
+        mask = None
     ):
         tokens, inverse_packed_batch = pack_one(tokens, '* n d')

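The new `mask` argument is applied directly to the attention logits in the next hunk, so it has to be a boolean tensor broadcastable to `sim`, with True marking positions that may be attended to. A minimal sketch of that assumption (the key-padding layout below is illustrative, not taken from the repo):

import torch

b, h, i, j = 2, 8, 5, 5                      # hypothetical batch, heads, query len, key len
sim = torch.randn(b, h, i, j)                # attention logits

keep = torch.tensor([[True, True, True, True, True],
                     [True, True, True, False, False]])   # (b, j) key padding mask

mask = keep[:, None, None, :]                # broadcast over heads and queries -> (b, 1, 1, j)
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)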
@@ -379,6 +385,20 @@ class Attention(Module):

         sim = sim * self.scale

+        # masking
+
+        mask_value = -torch.finfo(sim.dtype).max
+
+        if exists(mask):
+            sim = sim.masked_fill(~mask, mask_value)
+
+        if self.causal:
+            i, j = sim.shape[-2:]
+            causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)
+            sim = sim.masked_fill(causal_mask, mask_value)
+
         # attend

         attn = sim.softmax(dim = -1)

         # aggregate
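As a sanity check on the `triu(j - i + 1)` offset (a standalone sketch, assuming plain torch and the hypothetical sizes i = 3, j = 7): when a kv cache makes j larger than i, the offset lines the newest query up with the newest key, so each query only sees keys up to its own position.

import torch

i, j = 3, 7    # hypothetical: 3 new queries against 7 cached + current keys

causal_mask = torch.ones((i, j), dtype = torch.bool).triu(j - i + 1)

# True marks future keys to be masked out; the last query row is all False,
# i.e. the newest query may attend every key.
print(causal_mask)
# tensor([[False, False, False, False, False,  True,  True],
#         [False, False, False, False, False, False,  True],
#         [False, False, False, False, False, False, False]])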
@@ -691,7 +711,7 @@ class DynamicsModel(Module):
             layers.append(ModuleList([
                 rearrange_to_attend,
                 rearrange_from_attend,
-                Attention(dim = dim, **attn_kwargs),
+                Attention(dim = dim, causal = is_time_block, **attn_kwargs),
                 SwiGLUFeedforward(dim = dim, **ff_kwargs)
             ]))

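A rough sketch of why only the time blocks are made causal (axis names and shapes below are assumptions, not the repo's API): the dynamics model alternates attention over the spatial axis within a frame, which can stay bidirectional, and over the time axis across frames, which must not see future frames.

import torch
from einops import rearrange

b, t, s, d = 2, 4, 16, 64          # hypothetical batch, frames, spatial tokens, dim
tokens = torch.randn(b, t, s, d)

# space block: sequence axis is space, attention within a frame -> causal = False
space_seq = rearrange(tokens, 'b t s d -> (b t) s d')

# time block: sequence axis is time, attention across frames -> causal = True
time_seq = rearrange(tokens, 'b t s d -> (b s) t d')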