diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 6ea10ea..9b5a693 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -922,7 +922,8 @@ class Attention(Module): dim_kv_inner = dim_head * heads self.to_q = LinearNoBias(dim, dim_q_inner) - self.to_kv = LinearNoBias(dim, dim_kv_inner * 2) + self.to_k = LinearNoBias(dim, dim_kv_inner) + self.to_v = LinearNoBias(dim, dim_kv_inner) self.to_out = LinearNoBias(dim_q_inner, dim) # stability related @@ -930,6 +931,12 @@ class Attention(Module): self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads) self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads) + def muon_parameters(self): + return [ + *self.to_v.parameters(), + *self.to_out.parameters(), + ] + def forward( self, tokens, # (b n d) @@ -942,7 +949,7 @@ class Attention(Module): tokens = self.norm(tokens) - q, k, v = (self.to_q(tokens), *self.to_kv(tokens).chunk(2, dim = -1)) + q, k, v = (self.to_q(tokens), self.to_k(tokens), self.to_v(tokens)) # split heads @@ -1004,6 +1011,12 @@ class SwiGLUFeedforward(Module): self.proj_in = Linear(dim, dim_inner * 2) self.proj_out = Linear(dim_inner, dim) + def muon_parameters(self): + return [ + self.proj_in.weight, + self.proj_out.weight, + ] + def forward(self, x): x = self.norm(x) diff --git a/pyproject.toml b/pyproject.toml index 272a9ef..32e3c0a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.0.15" +version = "0.0.16" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }