allow controlling the update of the loss EMA from the dynamics model forward pass

lucidrains 2025-10-19 08:25:50 -07:00
parent 374667d8a9
commit f651d779e3
2 changed files with 7 additions and 6 deletions
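
In effect, the caller can now decide per forward pass whether the loss normalizers update their EMA statistics. A minimal usage sketch, assuming a constructed DynamicsWorldModel named `world_model` and inputs `video` / `rewards` (those names are illustrative; only `update_loss_ema` comes from this commit):

    # hypothetical usage: run an auxiliary / evaluation forward pass without
    # letting its loss scale bleed into the EMA statistics used for training
    losses = world_model(
        video,
        rewards = rewards,
        update_loss_ema = False   # skip the EMA update inside each loss normalizer
    )

Leaving `update_loss_ema = None` presumably preserves the previous behavior, with each normalizer falling back to its own default.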


@@ -2171,7 +2171,8 @@ class DynamicsWorldModel(Module):
         latent_is_noised = False,
         return_all_losses = False,
         return_agent_tokens = False,
-        add_autoregressive_action_loss = False
+        add_autoregressive_action_loss = False,
+        update_loss_ema = None
     ):
         # handle video or latents
@@ -2567,16 +2568,16 @@ class DynamicsWorldModel(Module):
         losses = WorldModelLosses(flow_loss, reward_loss, discrete_action_loss, continuous_action_loss)

         if exists(self.flow_loss_normalizer):
-            flow_loss = self.flow_loss_normalizer(flow_loss)
+            flow_loss = self.flow_loss_normalizer(flow_loss, update_ema = update_loss_ema)

         if exists(rewards) and exists(self.reward_loss_normalizer):
-            reward_loss = self.reward_loss_normalizer(reward_loss)
+            reward_loss = self.reward_loss_normalizer(reward_loss, update_ema = update_loss_ema)

         if exists(discrete_actions) and exists(self.discrete_actions_loss_normalizer):
-            discrete_action_loss = self.discrete_actions_loss_normalizer(discrete_action_loss)
+            discrete_action_loss = self.discrete_actions_loss_normalizer(discrete_action_loss, update_ema = update_loss_ema)

         if exists(continuous_actions) and exists(self.continuous_actions_loss_normalizer):
-            continuous_action_loss = self.continuous_actions_loss_normalizer(continuous_action_loss)
+            continuous_action_loss = self.continuous_actions_loss_normalizer(continuous_action_loss, update_ema = update_loss_ema)

         # gather losses - they sum across the multi token prediction steps for rewards and actions - eq (9)
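
For context, a minimal sketch of what a loss normalizer with a controllable EMA update might look like. This is an assumption for illustration, not the repository's actual normalizer; the class name, the unit-scale normalization, and the fall-back to `self.training` when `update_ema` is None are all guesses:

    import torch
    from torch.nn import Module

    class LossNormalizer(Module):
        # hypothetical: track an EMA of the raw loss magnitude and divide by it,
        # so each loss term enters the total at roughly unit scale
        def __init__(self, decay = 0.99, eps = 1e-5):
            super().__init__()
            self.decay = decay
            self.eps = eps
            self.register_buffer('ema_loss', torch.ones(()))

        def forward(self, loss, update_ema = None):
            if update_ema is None:
                update_ema = self.training  # assumed default behavior

            if update_ema:
                with torch.no_grad():
                    # ema <- decay * ema + (1 - decay) * loss
                    self.ema_loss.lerp_(loss.detach().mean(), 1. - self.decay)

            return loss / self.ema_loss.clamp(min = self.eps)

Threading `update_loss_ema` through the world model's forward, as this commit does, lets one call site govern all four normalizers (flow, reward, discrete and continuous action losses) at once.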

pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.41"
+version = "0.0.42"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }