more muon prep
This commit is contained in:
parent
b4763caff9
commit
15876d34cf
@ -1132,6 +1132,8 @@ class Attention(Module):
|
|||||||
self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads) if rmsnorm_key else nn.Identity()
|
self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads) if rmsnorm_key else nn.Identity()
|
||||||
|
|
||||||
def muon_parameters(self):
|
def muon_parameters(self):
|
||||||
|
# omit the queries and keys for now given what we learned from kimi 2 paper
|
||||||
|
|
||||||
return [
|
return [
|
||||||
*self.to_v.parameters(),
|
*self.to_v.parameters(),
|
||||||
*self.to_out.parameters(),
|
*self.to_out.parameters(),
|
||||||
@ -1298,6 +1300,15 @@ class AxialSpaceTimeTransformer(Module):
|
|||||||
|
|
||||||
self.num_special_spatial_tokens = num_special_spatial_tokens
|
self.num_special_spatial_tokens = num_special_spatial_tokens
|
||||||
|
|
||||||
|
def muon_parameters(self):
|
||||||
|
muon_params = []
|
||||||
|
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, (Attention, SwiGLUFeedforward)):
|
||||||
|
muon_params.extend(m.muon_parameters())
|
||||||
|
|
||||||
|
return muon_params
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
tokens, # (b t s d)
|
tokens, # (b t s d)
|
||||||
@ -1523,6 +1534,12 @@ class VideoTokenizer(Module):
|
|||||||
def device(self):
|
def device(self):
|
||||||
return self.zero.device
|
return self.zero.device
|
||||||
|
|
||||||
|
def muon_parameters(self):
|
||||||
|
return [
|
||||||
|
*self.encoder_transformer.muon_parameters(),
|
||||||
|
*self.decoder_transformer.muon_parameters()
|
||||||
|
]
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def tokenize(
|
def tokenize(
|
||||||
self,
|
self,
|
||||||
@ -1906,6 +1923,9 @@ class DynamicsWorldModel(Module):
|
|||||||
def device(self):
|
def device(self):
|
||||||
return self.zero.device
|
return self.zero.device
|
||||||
|
|
||||||
|
def muon_parameters(self):
|
||||||
|
return self.transformer.muon_parameters()
|
||||||
|
|
||||||
def get_times_from_signal_level(
|
def get_times_from_signal_level(
|
||||||
self,
|
self,
|
||||||
signal_levels,
|
signal_levels,
|
||||||
|
|||||||
@ -3,6 +3,8 @@ from torch.nn import Module
|
|||||||
|
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
|
|
||||||
|
from adam_atan2_pytorch import MuonAdamAtan2
|
||||||
|
|
||||||
from dreamer4.dreamer4 import (
|
from dreamer4.dreamer4 import (
|
||||||
VideoTokenizer,
|
VideoTokenizer,
|
||||||
DynamicsModel
|
DynamicsModel
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "dreamer4"
|
name = "dreamer4"
|
||||||
version = "0.0.52"
|
version = "0.0.53"
|
||||||
description = "Dreamer 4"
|
description = "Dreamer 4"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user