From 15876d34cf8ec27f17023576efa52bf523a492a9 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Tue, 21 Oct 2025 08:23:59 -0700 Subject: [PATCH] more muon prep --- dreamer4/dreamer4.py | 20 ++++++++++++++++++++ dreamer4/trainers.py | 2 ++ pyproject.toml | 2 +- 3 files changed, 23 insertions(+), 1 deletion(-) diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index dd11a69..3347187 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -1132,6 +1132,8 @@ class Attention(Module): self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads) if rmsnorm_key else nn.Identity() def muon_parameters(self): + # omit the queries and keys for now given what we learned from kimi 2 paper + return [ *self.to_v.parameters(), *self.to_out.parameters(), @@ -1298,6 +1300,15 @@ class AxialSpaceTimeTransformer(Module): self.num_special_spatial_tokens = num_special_spatial_tokens + def muon_parameters(self): + muon_params = [] + + for m in self.modules(): + if isinstance(m, (Attention, SwiGLUFeedforward)): + muon_params.extend(m.muon_parameters()) + + return muon_params + def forward( self, tokens, # (b t s d) @@ -1523,6 +1534,12 @@ class VideoTokenizer(Module): def device(self): return self.zero.device + def muon_parameters(self): + return [ + *self.encoder_transformer.muon_parameters(), + *self.decoder_transformer.muon_parameters() + ] + @torch.no_grad() def tokenize( self, @@ -1906,6 +1923,9 @@ class DynamicsWorldModel(Module): def device(self): return self.zero.device + def muon_parameters(self): + return self.transformer.muon_parameters() + def get_times_from_signal_level( self, signal_levels, diff --git a/dreamer4/trainers.py b/dreamer4/trainers.py index c450ee5..755da23 100644 --- a/dreamer4/trainers.py +++ b/dreamer4/trainers.py @@ -3,6 +3,8 @@ from torch.nn import Module from accelerate import Accelerator +from adam_atan2_pytorch import MuonAdamAtan2 + from dreamer4.dreamer4 import ( VideoTokenizer, DynamicsModel diff --git a/pyproject.toml b/pyproject.toml index 1a76c38..307a4b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.0.52" +version = "0.0.53" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }