separate out the key from the value projections in attention for muon

This commit is contained in:
lucidrains 2025-10-12 09:42:22 -07:00
parent ab5de6795f
commit c5e64ff4ce
2 changed files with 16 additions and 3 deletions

View File

@ -922,7 +922,8 @@ class Attention(Module):
dim_kv_inner = dim_head * heads
self.to_q = LinearNoBias(dim, dim_q_inner)
self.to_kv = LinearNoBias(dim, dim_kv_inner * 2)
self.to_k = LinearNoBias(dim, dim_kv_inner)
self.to_v = LinearNoBias(dim, dim_kv_inner)
self.to_out = LinearNoBias(dim_q_inner, dim)
# stability related
@ -930,6 +931,12 @@ class Attention(Module):
self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads)
self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads)
def muon_parameters(self):
    """Return the parameters that should be optimized with Muon.

    Only the value and output projections are handed to Muon; the
    query and key projections are deliberately excluded.
    """
    muon_params = []
    muon_params.extend(self.to_v.parameters())
    muon_params.extend(self.to_out.parameters())
    return muon_params
def forward(
self,
tokens, # (b n d)
@ -942,7 +949,7 @@ class Attention(Module):
tokens = self.norm(tokens)
q, k, v = (self.to_q(tokens), *self.to_kv(tokens).chunk(2, dim = -1))
q, k, v = (self.to_q(tokens), self.to_k(tokens), self.to_v(tokens))
# split heads
@ -1004,6 +1011,12 @@ class SwiGLUFeedforward(Module):
self.proj_in = Linear(dim, dim_inner * 2)
self.proj_out = Linear(dim_inner, dim)
def muon_parameters(self):
    """Return the weight matrices to be optimized with Muon.

    Both the input and output projection weights of the feedforward
    are included (biases, if any, are not).
    """
    return [self.proj_in.weight, self.proj_out.weight]
def forward(self, x):
x = self.norm(x)

View File

@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.15"
version = "0.0.16"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }