allow for setting different loss weights for each MTP head (perhaps more weight on the next vs some far out prediction)
This commit is contained in:
parent
f651d779e3
commit
ecbe13efe8
@ -1612,13 +1612,13 @@ class DynamicsWorldModel(Module):
|
||||
num_discrete_actions: int | tuple[int, ...] = 0,
|
||||
num_continuous_actions = 0,
|
||||
continuous_norm_stats = None,
|
||||
reward_loss_weight = 1.,
|
||||
multi_token_pred_len = 8,
|
||||
value_head_mlp_depth = 3,
|
||||
policy_head_mlp_depth = 3,
|
||||
latent_flow_loss_weight = 1.,
|
||||
discrete_action_loss_weight = 1.,
|
||||
continuous_action_loss_weight = 1.,
|
||||
reward_loss_weight: float | list[float] = 1.,
|
||||
discrete_action_loss_weight: float | list[float] = 1.,
|
||||
continuous_action_loss_weight: float | list[float] = 1.,
|
||||
num_latent_genes = 0, # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
|
||||
num_residual_streams = 1,
|
||||
gae_discount_factor = 0.997,
|
||||
@ -1743,9 +1743,6 @@ class DynamicsWorldModel(Module):
|
||||
squeeze_unembed_preds = False
|
||||
)
|
||||
|
||||
self.discrete_action_loss_weight = discrete_action_loss_weight
|
||||
self.continuous_action_loss_weight = continuous_action_loss_weight
|
||||
|
||||
# multi token prediction length
|
||||
|
||||
self.multi_token_pred_len = multi_token_pred_len
|
||||
@ -1812,7 +1809,14 @@ class DynamicsWorldModel(Module):
|
||||
self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_discrete_actions > 0 else None
|
||||
|
||||
self.latent_flow_loss_weight = latent_flow_loss_weight
|
||||
self.reward_loss_weight = reward_loss_weight
|
||||
|
||||
self.register_buffer('reward_loss_weight', tensor(reward_loss_weight))
|
||||
self.register_buffer('discrete_action_loss_weight', tensor(discrete_action_loss_weight))
|
||||
self.register_buffer('continuous_action_loss_weight', tensor(continuous_action_loss_weight))
|
||||
|
||||
assert self.reward_loss_weight.numel() in {1, multi_token_pred_len}
|
||||
assert self.discrete_action_loss_weight.numel() in {1, multi_token_pred_len}
|
||||
assert self.continuous_action_loss_weight.numel() in {1, multi_token_pred_len}
|
||||
|
||||
self.register_buffer('zero', tensor(0.), persistent = False)
|
||||
|
||||
@ -2583,9 +2587,9 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
total_loss = (
|
||||
flow_loss * self.latent_flow_loss_weight +
|
||||
reward_loss.sum() * self.reward_loss_weight +
|
||||
discrete_action_loss.sum() * self.discrete_action_loss_weight +
|
||||
continuous_action_loss.sum() * self.continuous_action_loss_weight
|
||||
(reward_loss * self.reward_loss_weight).sum() +
|
||||
(discrete_action_loss * self.discrete_action_loss_weight).sum() +
|
||||
(continuous_action_loss * self.continuous_action_loss_weight).sum()
|
||||
)
|
||||
|
||||
if not return_all_losses:
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dreamer4"
|
||||
version = "0.0.42"
|
||||
version = "0.0.43"
|
||||
description = "Dreamer 4"
|
||||
authors = [
|
||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user