another measure for the attending to nothing issue
This commit is contained in:
parent
55574c054e
commit
1345326656
@ -1078,6 +1078,7 @@ class Attention(Module):
|
|||||||
query_heads = None,
|
query_heads = None,
|
||||||
heads = 8,
|
heads = 8,
|
||||||
pre_rmsnorm = True,
|
pre_rmsnorm = True,
|
||||||
|
gate_values = True
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
|
self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
|
||||||
@ -1100,6 +1101,17 @@ class Attention(Module):
|
|||||||
self.to_v = LinearNoBias(dim, dim_kv_inner)
|
self.to_v = LinearNoBias(dim, dim_kv_inner)
|
||||||
self.to_out = LinearNoBias(dim_q_inner, dim)
|
self.to_out = LinearNoBias(dim_q_inner, dim)
|
||||||
|
|
||||||
|
# alphafold gating per head, for attending to nothing
|
||||||
|
|
||||||
|
self.to_gates = None
|
||||||
|
|
||||||
|
if gate_values:
|
||||||
|
self.to_gates = Sequential(
|
||||||
|
LinearNoBias(dim, query_heads),
|
||||||
|
Rearrange('b n h -> b h n 1'),
|
||||||
|
nn.Sigmoid()
|
||||||
|
)
|
||||||
|
|
||||||
# stability related
|
# stability related
|
||||||
|
|
||||||
self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads)
|
self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads)
|
||||||
@ -1153,6 +1165,12 @@ class Attention(Module):
|
|||||||
|
|
||||||
out = attend_fn(q, k, v)
|
out = attend_fn(q, k, v)
|
||||||
|
|
||||||
|
# gate values
|
||||||
|
|
||||||
|
if exists(self.to_gates):
|
||||||
|
gates = self.to_gates(tokens)
|
||||||
|
out = out * gates
|
||||||
|
|
||||||
# merge heads
|
# merge heads
|
||||||
|
|
||||||
out = self.merge_heads(out)
|
out = self.merge_heads(out)
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "dreamer4"
|
name = "dreamer4"
|
||||||
version = "0.0.46"
|
version = "0.0.47"
|
||||||
description = "Dreamer 4"
|
description = "Dreamer 4"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user