make attention masking correct for dynamics model

parent 986bf4c529
commit 77ad96ded2
@@ -317,7 +317,8 @@ class Attention(Module):
         dim_head = 64,
         heads = 8,
         softclamp_value = 50.,
-        pre_rmsnorm = True
+        pre_rmsnorm = True,
+        causal = False
     ):
         super().__init__()
         self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
@@ -331,6 +332,10 @@ class Attention(Module):
         self.to_kv = LinearNoBias(dim, dim_inner * 2)
         self.to_out = LinearNoBias(dim_inner, dim)
 
+        # masking related
+
+        self.causal = causal
+
         # stability related
 
         self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads)
@@ -342,7 +347,8 @@ class Attention(Module):
         self,
         tokens, # (b n d)
         kv_cache = None,
-        return_kv_cache = False
+        return_kv_cache = False,
+        mask = None
     ):
         tokens, inverse_packed_batch = pack_one(tokens, '* n d')
 
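The new `mask` argument is consumed below with `masked_fill(~mask, mask_value)` against the attention scores, so any boolean mask that broadcasts against `sim` works. A minimal sketch of building a key-padding mask from sequence lengths — the `(batch, heads, i, j)` score shape and the `key_padding_mask` helper are assumptions for illustration, not part of the diff:

import torch

# hypothetical helper: True where a key position is real, False where it is padding
def key_padding_mask(lengths, max_len):
    mask = torch.arange(max_len) < lengths[:, None]   # (batch, j)
    return mask[:, None, None, :]                     # (batch, 1, 1, j), broadcasts over heads and queries

lengths = torch.tensor([3, 5])
mask = key_padding_mask(lengths, max_len = 5)
# mask[0, 0, 0] -> tensor([ True,  True,  True, False, False])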
@@ -379,6 +385,20 @@ class Attention(Module):
 
         sim = sim * self.scale
 
+        # masking
+
+        mask_value = -torch.finfo(sim.dtype).max
+
+        if exists(mask):
+            sim = sim.masked_fill(~mask, mask_value)
+
+        if self.causal:
+            i, j = sim.shape[-2:]
+            causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)
+            sim = sim.masked_fill(causal_mask, mask_value)
+
+        # attend
+
         attn = sim.softmax(dim = -1)
 
         # aggregate
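The `triu(j - i + 1)` offset keeps the causal mask aligned when a kv cache makes `j` longer than `i`: each of the `i` new queries may attend to every cached key and to itself, but not beyond. A standalone check of that offset (values chosen arbitrarily for illustration):

import torch

i, j = 2, 5   # 2 new queries over 3 cached keys plus themselves

causal_mask = torch.ones((i, j), dtype = torch.bool).triu(j - i + 1)
# tensor([[False, False, False, False,  True],
#         [False, False, False, False, False]])
# query 0 (absolute position 3) is blocked only from key 4, its future;
# query 1 (absolute position 4) sees every key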
@@ -691,7 +711,7 @@ class DynamicsModel(Module):
             layers.append(ModuleList([
                 rearrange_to_attend,
                 rearrange_from_attend,
-                Attention(dim = dim, **attn_kwargs),
+                Attention(dim = dim, causal = is_time_block, **attn_kwargs),
                 SwiGLUFeedforward(dim = dim, **ff_kwargs)
             ]))
 
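`causal = is_time_block` reads as: attention blocks that mix information across time are now causal, while the remaining (presumably spatial) blocks stay bidirectional. A rough sketch of that factorization, assuming tokens are laid out as `(batch, time, space, dim)` and that `rearrange_to_attend` / `rearrange_from_attend` fold the non-attended axis into the batch — none of which is spelled out in this hunk:

import torch
from einops import rearrange

b, t, s, d = 1, 4, 16, 512
tokens = torch.randn(b, t, s, d)

# time block: fold space into batch, attend causally across time steps
time_tokens = rearrange(tokens, 'b t s d -> (b s) t d')   # pair with causal = True

# space block: fold time into batch, attend freely within a time step
space_tokens = rearrange(tokens, 'b t s d -> (b t) s d')  # pair with causal = False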