From ecbe13efe83bcfeab1fc9739baf83c98d81d324b Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Sun, 19 Oct 2025 08:37:56 -0700
Subject: [PATCH] allow for setting different loss weights for each MTP head
 (perhaps more weight on the next vs some far out prediction)

---
 dreamer4/dreamer4.py | 24 ++++++++++++++----------
 pyproject.toml       |  2 +-
 2 files changed, 15 insertions(+), 11 deletions(-)

diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index fdc7c99..64c62f0 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -1612,13 +1612,13 @@ class DynamicsWorldModel(Module):
         num_discrete_actions: int | tuple[int, ...] = 0,
         num_continuous_actions = 0,
         continuous_norm_stats = None,
-        reward_loss_weight = 1.,
         multi_token_pred_len = 8,
         value_head_mlp_depth = 3,
         policy_head_mlp_depth = 3,
         latent_flow_loss_weight = 1.,
-        discrete_action_loss_weight = 1.,
-        continuous_action_loss_weight = 1.,
+        reward_loss_weight: float | list[float] = 1.,
+        discrete_action_loss_weight: float | list[float] = 1.,
+        continuous_action_loss_weight: float | list[float] = 1.,
         num_latent_genes = 0, # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
         num_residual_streams = 1,
         gae_discount_factor = 0.997,
@@ -1743,9 +1743,6 @@
             squeeze_unembed_preds = False
         )

-        self.discrete_action_loss_weight = discrete_action_loss_weight
-        self.continuous_action_loss_weight = continuous_action_loss_weight
-
         # multi token prediction length

         self.multi_token_pred_len = multi_token_pred_len
@@ -1812,7 +1809,14 @@
         self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_discrete_actions > 0 else None

         self.latent_flow_loss_weight = latent_flow_loss_weight
-        self.reward_loss_weight = reward_loss_weight
+
+        self.register_buffer('reward_loss_weight', tensor(reward_loss_weight))
+        self.register_buffer('discrete_action_loss_weight', tensor(discrete_action_loss_weight))
+        self.register_buffer('continuous_action_loss_weight', tensor(continuous_action_loss_weight))
+
+        assert self.reward_loss_weight.numel() in {1, multi_token_pred_len}
+        assert self.discrete_action_loss_weight.numel() in {1, multi_token_pred_len}
+        assert self.continuous_action_loss_weight.numel() in {1, multi_token_pred_len}

         self.register_buffer('zero', tensor(0.), persistent = False)

@@ -2583,9 +2587,9 @@

         total_loss = (
             flow_loss * self.latent_flow_loss_weight +
-            reward_loss.sum() * self.reward_loss_weight +
-            discrete_action_loss.sum() * self.discrete_action_loss_weight +
-            continuous_action_loss.sum() * self.continuous_action_loss_weight
+            (reward_loss * self.reward_loss_weight).sum() +
+            (discrete_action_loss * self.discrete_action_loss_weight).sum() +
+            (continuous_action_loss * self.continuous_action_loss_weight).sum()
         )

         if not return_all_losses:
diff --git a/pyproject.toml b/pyproject.toml
index 27e68f6..38418a4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.42"
+version = "0.0.43"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
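
Note (not part of the patch): a minimal usage sketch of the per-head weighting introduced above. It assumes DynamicsWorldModel is importable from the package root and that any constructor arguments not shown can stay at their defaults; the weight values themselves are only illustrative.

    from dreamer4 import DynamicsWorldModel

    multi_token_pred_len = 8

    # geometrically decaying weights - most weight on the next-step head, least on the farthest-out one
    per_head_weights = [0.5 ** i for i in range(multi_token_pred_len)]

    world_model = DynamicsWorldModel(
        # ... other constructor arguments omitted; only the kwargs touched by this patch are shown
        multi_token_pred_len = multi_token_pred_len,
        reward_loss_weight = per_head_weights,            # list[float] of length 1 or multi_token_pred_len
        discrete_action_loss_weight = 1.,                 # a plain float still broadcasts over all heads
        continuous_action_loss_weight = per_head_weights
    )

A scalar keeps the old behavior (one weight broadcast across every prediction head), while a list of length multi_token_pred_len weights each head separately before the per-head losses are summed into the total loss.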