From 134532665653f1fea5b96f67de6581a020c44a44 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Mon, 20 Oct 2025 10:32:31 -0700 Subject: [PATCH] another measure for the attending to nothing issue --- dreamer4/dreamer4.py | 18 ++++++++++++++++++ pyproject.toml | 2 +- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index d16e12d..fbf7d29 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -1078,6 +1078,7 @@ class Attention(Module): query_heads = None, heads = 8, pre_rmsnorm = True, + gate_values = True ): super().__init__() self.norm = RMSNorm(dim) if pre_rmsnorm else Identity() @@ -1100,6 +1101,17 @@ class Attention(Module): self.to_v = LinearNoBias(dim, dim_kv_inner) self.to_out = LinearNoBias(dim_q_inner, dim) + # alphafold gating per head, for attending to nothing + + self.to_gates = None + + if gate_values: + self.to_gates = Sequential( + LinearNoBias(dim, query_heads), + Rearrange('b n h -> b h n 1'), + nn.Sigmoid() + ) + # stability related self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads) @@ -1153,6 +1165,12 @@ class Attention(Module): out = attend_fn(q, k, v) + # gate values + + if exists(self.to_gates): + gates = self.to_gates(tokens) + out = out * gates + # merge heads out = self.merge_heads(out) diff --git a/pyproject.toml b/pyproject.toml index e65affc..5275757 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.0.46" +version = "0.0.47" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }