Separate the key and value projections in attention so the value projection can be handed to Muon on its own
This commit is contained in:
parent
ab5de6795f
commit
c5e64ff4ce
@ -922,7 +922,8 @@ class Attention(Module):
|
|||||||
dim_kv_inner = dim_head * heads
|
dim_kv_inner = dim_head * heads
|
||||||
|
|
||||||
self.to_q = LinearNoBias(dim, dim_q_inner)
|
self.to_q = LinearNoBias(dim, dim_q_inner)
|
||||||
self.to_kv = LinearNoBias(dim, dim_kv_inner * 2)
|
self.to_k = LinearNoBias(dim, dim_kv_inner)
|
||||||
|
self.to_v = LinearNoBias(dim, dim_kv_inner)
|
||||||
self.to_out = LinearNoBias(dim_q_inner, dim)
|
self.to_out = LinearNoBias(dim_q_inner, dim)
|
||||||
|
|
||||||
# stability related
|
# stability related
|
||||||
@ -930,6 +931,12 @@ class Attention(Module):
|
|||||||
self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads)
|
self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads)
|
||||||
self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads)
|
self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads)
|
||||||
|
|
||||||
|
def muon_parameters(self):
|
||||||
|
return [
|
||||||
|
*self.to_v.parameters(),
|
||||||
|
*self.to_out.parameters(),
|
||||||
|
]
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
tokens, # (b n d)
|
tokens, # (b n d)
|
||||||
@ -942,7 +949,7 @@ class Attention(Module):
|
|||||||
|
|
||||||
tokens = self.norm(tokens)
|
tokens = self.norm(tokens)
|
||||||
|
|
||||||
q, k, v = (self.to_q(tokens), *self.to_kv(tokens).chunk(2, dim = -1))
|
q, k, v = (self.to_q(tokens), self.to_k(tokens), self.to_v(tokens))
|
||||||
|
|
||||||
# split heads
|
# split heads
|
||||||
|
|
||||||
@ -1004,6 +1011,12 @@ class SwiGLUFeedforward(Module):
|
|||||||
self.proj_in = Linear(dim, dim_inner * 2)
|
self.proj_in = Linear(dim, dim_inner * 2)
|
||||||
self.proj_out = Linear(dim_inner, dim)
|
self.proj_out = Linear(dim_inner, dim)
|
||||||
|
|
||||||
|
def muon_parameters(self):
|
||||||
|
return [
|
||||||
|
self.proj_in.weight,
|
||||||
|
self.proj_out.weight,
|
||||||
|
]
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.norm(x)
|
x = self.norm(x)
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "dreamer4"
|
name = "dreamer4"
|
||||||
version = "0.0.15"
|
version = "0.0.16"
|
||||||
description = "Dreamer 4"
|
description = "Dreamer 4"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user