another measure for the attending to nothing issue
This commit is contained in:
parent
55574c054e
commit
1345326656
@ -1078,6 +1078,7 @@ class Attention(Module):
|
||||
query_heads = None,
|
||||
heads = 8,
|
||||
pre_rmsnorm = True,
|
||||
gate_values = True
|
||||
):
|
||||
super().__init__()
|
||||
self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
|
||||
@ -1100,6 +1101,17 @@ class Attention(Module):
|
||||
self.to_v = LinearNoBias(dim, dim_kv_inner)
|
||||
self.to_out = LinearNoBias(dim_q_inner, dim)
|
||||
|
||||
# alphafold gating per head, for attending to nothing
|
||||
|
||||
self.to_gates = None
|
||||
|
||||
if gate_values:
|
||||
self.to_gates = Sequential(
|
||||
LinearNoBias(dim, query_heads),
|
||||
Rearrange('b n h -> b h n 1'),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
# stability related
|
||||
|
||||
self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads)
|
||||
@ -1153,6 +1165,12 @@ class Attention(Module):
|
||||
|
||||
out = attend_fn(q, k, v)
|
||||
|
||||
# gate values
|
||||
|
||||
if exists(self.to_gates):
|
||||
gates = self.to_gates(tokens)
|
||||
out = out * gates
|
||||
|
||||
# merge heads
|
||||
|
||||
out = self.merge_heads(out)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dreamer4"
|
||||
version = "0.0.46"
|
||||
version = "0.0.47"
|
||||
description = "Dreamer 4"
|
||||
authors = [
|
||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user