separate out the key from the value projections in attention for muon
parent ab5de6795f
commit c5e64ff4ce
@@ -922,7 +922,8 @@ class Attention(Module):
         dim_kv_inner = dim_head * heads
 
         self.to_q = LinearNoBias(dim, dim_q_inner)
-        self.to_kv = LinearNoBias(dim, dim_kv_inner * 2)
+        self.to_k = LinearNoBias(dim, dim_kv_inner)
+        self.to_v = LinearNoBias(dim, dim_kv_inner)
         self.to_out = LinearNoBias(dim_q_inner, dim)
 
         # stability related
@@ -930,6 +931,12 @@ class Attention(Module):
         self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads)
         self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads)
 
+    def muon_parameters(self):
+        return [
+            *self.to_v.parameters(),
+            *self.to_out.parameters(),
+        ]
+
     def forward(
         self,
         tokens, # (b n d)
@@ -942,7 +949,7 @@ class Attention(Module):
 
         tokens = self.norm(tokens)
 
-        q, k, v = (self.to_q(tokens), *self.to_kv(tokens).chunk(2, dim = -1))
+        q, k, v = (self.to_q(tokens), self.to_k(tokens), self.to_v(tokens))
 
         # split heads
 
@@ -1004,6 +1011,12 @@ class SwiGLUFeedforward(Module):
         self.proj_in = Linear(dim, dim_inner * 2)
         self.proj_out = Linear(dim_inner, dim)
 
+    def muon_parameters(self):
+        return [
+            self.proj_in.weight,
+            self.proj_out.weight,
+        ]
+
     def forward(self, x):
         x = self.norm(x)
 
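With the fused to_kv split apart, Attention exposes only its value and output projections through muon_parameters(), while SwiGLUFeedforward exposes both of its projection weights; the query and key projections are left to whatever fallback optimizer trains the remaining parameters. Below is a minimal sketch, not part of this repo, of how these hooks might be collected into optimizer groups. Muon is a placeholder for any torch.optim-style Muon implementation, and the learning rates are illustrative only.

import torch
from torch import nn

def build_param_groups(model: nn.Module):
    muon_params = []

    # gather parameters from every submodule that opts into Muon
    for module in model.modules():
        if hasattr(module, 'muon_parameters'):
            muon_params.extend(module.muon_parameters())

    muon_ids = set(map(id, muon_params))

    # everything else (embeddings, norms, query/key projections, ...)
    # falls back to the other optimizer
    other_params = [p for p in model.parameters() if id(p) not in muon_ids]

    return muon_params, other_params

# hypothetical usage:
# muon_params, other_params = build_param_groups(model)
# opt_muon = Muon(muon_params, lr = 2e-2)               # assumed Muon implementation
# opt_rest = torch.optim.AdamW(other_params, lr = 3e-4)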
pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.15"
+version = "0.0.16"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }