another measure for the attending to nothing issue

This commit is contained in:
lucidrains 2025-10-20 10:32:31 -07:00
parent 55574c054e
commit 1345326656
2 changed files with 19 additions and 1 deletions

View File

@ -1078,6 +1078,7 @@ class Attention(Module):
query_heads = None, query_heads = None,
heads = 8, heads = 8,
pre_rmsnorm = True, pre_rmsnorm = True,
gate_values = True
): ):
super().__init__() super().__init__()
self.norm = RMSNorm(dim) if pre_rmsnorm else Identity() self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
@ -1100,6 +1101,17 @@ class Attention(Module):
self.to_v = LinearNoBias(dim, dim_kv_inner)
self.to_out = LinearNoBias(dim_q_inner, dim)
# alphafold gating per head, for attending to nothing
self.to_gates = None
if gate_values:
self.to_gates = Sequential(
LinearNoBias(dim, query_heads),
Rearrange('b n h -> b h n 1'),
nn.Sigmoid()
)
# stability related
self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads)
@ -1153,6 +1165,12 @@ class Attention(Module):
out = attend_fn(q, k, v)
# gate values
if exists(self.to_gates):
gates = self.to_gates(tokens)
out = out * gates
# merge heads
out = self.merge_heads(out)

View File

@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.46" → "0.0.47"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }