more muon prep
This commit is contained in:
parent
b4763caff9
commit
15876d34cf
@ -1132,6 +1132,8 @@ class Attention(Module):
|
||||
self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads) if rmsnorm_key else nn.Identity()
|
||||
|
||||
def muon_parameters(self):
|
||||
# omit the queries and keys for now given what we learned from kimi 2 paper
|
||||
|
||||
return [
|
||||
*self.to_v.parameters(),
|
||||
*self.to_out.parameters(),
|
||||
@ -1298,6 +1300,15 @@ class AxialSpaceTimeTransformer(Module):
|
||||
|
||||
self.num_special_spatial_tokens = num_special_spatial_tokens
|
||||
|
||||
def muon_parameters(self):
    """Gather the muon parameters reported by attention and feedforward submodules.

    Walks `self.modules()` and, for every `Attention` or `SwiGLUFeedforward`
    instance encountered, collects whatever that module's own
    `muon_parameters()` yields.

    Returns:
        list: the collected parameters, in module traversal order.
    """
    muon_capable = (Attention, SwiGLUFeedforward)

    return [
        param
        for module in self.modules()
        if isinstance(module, muon_capable)
        for param in module.muon_parameters()
    ]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
tokens, # (b t s d)
|
||||
@ -1523,6 +1534,12 @@ class VideoTokenizer(Module):
|
||||
def device(self):
    """Return the `device` attribute of `self.zero`."""
    zero = self.zero
    return zero.device
|
||||
|
||||
def muon_parameters(self):
    """Return the encoder transformer's muon parameters followed by the decoder's.

    Returns:
        list: a new list holding the results of
        `self.encoder_transformer.muon_parameters()` and then
        `self.decoder_transformer.muon_parameters()`.
    """
    # encoder is queried first, then decoder - same order as the flattened result
    encoder_params = self.encoder_transformer.muon_parameters()
    decoder_params = self.decoder_transformer.muon_parameters()

    return [*encoder_params, *decoder_params]
|
||||
|
||||
@torch.no_grad()
|
||||
def tokenize(
|
||||
self,
|
||||
@ -1906,6 +1923,9 @@ class DynamicsWorldModel(Module):
|
||||
def device(self):
    """Return the `device` attribute of `self.zero`."""
    zero = self.zero
    return zero.device
|
||||
|
||||
def muon_parameters(self):
    """Delegate to `self.transformer.muon_parameters()` and return its result unchanged."""
    transformer = self.transformer
    return transformer.muon_parameters()
|
||||
|
||||
def get_times_from_signal_level(
|
||||
self,
|
||||
signal_levels,
|
||||
|
||||
@ -3,6 +3,8 @@ from torch.nn import Module
|
||||
|
||||
from accelerate import Accelerator
|
||||
|
||||
from adam_atan2_pytorch import MuonAdamAtan2
|
||||
|
||||
from dreamer4.dreamer4 import (
|
||||
VideoTokenizer,
|
||||
DynamicsModel
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dreamer4"
|
||||
version = "0.0.52"
|
||||
version = "0.0.53"
|
||||
description = "Dreamer 4"
|
||||
authors = [
|
||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user