allow for setting different loss weights for each MTP head (perhaps more weight on the next prediction vs. some far-out one)

lucidrains 2025-10-19 08:37:56 -07:00
parent f651d779e3
commit ecbe13efe8
2 changed files with 15 additions and 11 deletions

View File

@@ -1612,13 +1612,13 @@ class DynamicsWorldModel(Module):
         num_discrete_actions: int | tuple[int, ...] = 0,
         num_continuous_actions = 0,
         continuous_norm_stats = None,
-        reward_loss_weight = 1.,
         multi_token_pred_len = 8,
         value_head_mlp_depth = 3,
         policy_head_mlp_depth = 3,
         latent_flow_loss_weight = 1.,
-        discrete_action_loss_weight = 1.,
-        continuous_action_loss_weight = 1.,
+        reward_loss_weight: float | list[float] = 1.,
+        discrete_action_loss_weight: float | list[float] = 1.,
+        continuous_action_loss_weight: float | list[float] = 1.,
         num_latent_genes = 0, # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
         num_residual_streams = 1,
         gae_discount_factor = 0.997,
@@ -1743,9 +1743,6 @@ class DynamicsWorldModel(Module):
             squeeze_unembed_preds = False
         )
 
-        self.discrete_action_loss_weight = discrete_action_loss_weight
-        self.continuous_action_loss_weight = continuous_action_loss_weight
-
         # multi token prediction length
 
         self.multi_token_pred_len = multi_token_pred_len
@@ -1812,7 +1809,14 @@ class DynamicsWorldModel(Module):
         self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_continuous_actions > 0 else None
 
         self.latent_flow_loss_weight = latent_flow_loss_weight
-        self.reward_loss_weight = reward_loss_weight
+
+        self.register_buffer('reward_loss_weight', tensor(reward_loss_weight))
+        self.register_buffer('discrete_action_loss_weight', tensor(discrete_action_loss_weight))
+        self.register_buffer('continuous_action_loss_weight', tensor(continuous_action_loss_weight))
+
+        assert self.reward_loss_weight.numel() in {1, multi_token_pred_len}
+        assert self.discrete_action_loss_weight.numel() in {1, multi_token_pred_len}
+        assert self.continuous_action_loss_weight.numel() in {1, multi_token_pred_len}
 
         self.register_buffer('zero', tensor(0.), persistent = False)
@@ -2583,9 +2587,9 @@ class DynamicsWorldModel(Module):
         total_loss = (
             flow_loss * self.latent_flow_loss_weight +
-            reward_loss.sum() * self.reward_loss_weight +
-            discrete_action_loss.sum() * self.discrete_action_loss_weight +
-            continuous_action_loss.sum() * self.continuous_action_loss_weight
+            (reward_loss * self.reward_loss_weight).sum() +
+            (discrete_action_loss * self.discrete_action_loss_weight).sum() +
+            (continuous_action_loss * self.continuous_action_loss_weight).sum()
         )
 
         if not return_all_losses:
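
For clarity, a minimal sketch of the weighting semantics this commit introduces; the tensor contents and the 0.9 decay factor are illustrative stand-ins, only the numel invariant and the weight-then-sum reduction mirror the diff above:

    import torch
    from torch import tensor

    multi_token_pred_len = 8

    # one weight per MTP head, decayed so near-term predictions count more
    reward_loss_weight = tensor([0.9 ** i for i in range(multi_token_pred_len)])

    # the commit's invariant: a single scalar, or exactly one weight per head
    assert reward_loss_weight.numel() in {1, multi_token_pred_len}

    # stand-in per-head losses, shape (multi_token_pred_len,)
    reward_loss = torch.rand(multi_token_pred_len)

    # before: reward_loss.sum() * weight, valid only for a scalar weight
    # after:  weight each head's loss, then reduce
    total_loss = (reward_loss * reward_loss_weight).sum()

When the weight is a scalar, (loss * weight).sum() equals loss.sum() * weight, so the previous scalar behavior is preserved.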

View File

@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.42"
+version = "0.0.43"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
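
A note on the design choice of register_buffer in the first file's diff: a buffer follows the module across device and dtype moves and (by default) persists in the state dict, without becoming a trainable parameter. A toy sketch, with a hypothetical module name:

    import torch
    from torch import nn, tensor

    class Toy(nn.Module):
        def __init__(self, loss_weight = [1., .5]):
            super().__init__()
            # stored as a buffer: moves with .to(device), saved in state_dict,
            # but excluded from parameters() and hence from the optimizer
            self.register_buffer('loss_weight', tensor(loss_weight))

    toy = Toy()
    assert 'loss_weight' in toy.state_dict()
    assert len(list(toy.parameters())) == 0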