more muon prep

This commit is contained in:
lucidrains 2025-10-21 08:23:59 -07:00
parent b4763caff9
commit 15876d34cf
3 changed files with 23 additions and 1 deletions

View File

@ -1132,6 +1132,8 @@ class Attention(Module):
self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads) if rmsnorm_key else nn.Identity()
def muon_parameters(self):
# omit the queries and keys for now given what we learned from kimi 2 paper
return [
*self.to_v.parameters(),
*self.to_out.parameters(),
@ -1298,6 +1300,15 @@ class AxialSpaceTimeTransformer(Module):
self.num_special_spatial_tokens = num_special_spatial_tokens
def muon_parameters(self):
muon_params = []
for m in self.modules():
if isinstance(m, (Attention, SwiGLUFeedforward)):
muon_params.extend(m.muon_parameters())
return muon_params
def forward(
self,
tokens, # (b t s d)
@ -1523,6 +1534,12 @@ class VideoTokenizer(Module):
def device(self):
return self.zero.device
def muon_parameters(self):
return [
*self.encoder_transformer.muon_parameters(),
*self.decoder_transformer.muon_parameters()
]
@torch.no_grad()
def tokenize(
self,
@ -1906,6 +1923,9 @@ class DynamicsWorldModel(Module):
def device(self):
return self.zero.device
def muon_parameters(self):
return self.transformer.muon_parameters()
def get_times_from_signal_level(
self,
signal_levels,

View File

@ -3,6 +3,8 @@ from torch.nn import Module
from accelerate import Accelerator
from adam_atan2_pytorch import MuonAdamAtan2
from dreamer4.dreamer4 import (
VideoTokenizer,
DynamicsModel

View File

@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.52"
version = "0.0.53"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }