Allow setting a different loss weight for each MTP head (e.g. more weight on the next-step prediction than on far-out predictions)
This commit is contained in:
parent
f651d779e3
commit
ecbe13efe8
@ -1612,13 +1612,13 @@ class DynamicsWorldModel(Module):
|
|||||||
num_discrete_actions: int | tuple[int, ...] = 0,
|
num_discrete_actions: int | tuple[int, ...] = 0,
|
||||||
num_continuous_actions = 0,
|
num_continuous_actions = 0,
|
||||||
continuous_norm_stats = None,
|
continuous_norm_stats = None,
|
||||||
reward_loss_weight = 1.,
|
|
||||||
multi_token_pred_len = 8,
|
multi_token_pred_len = 8,
|
||||||
value_head_mlp_depth = 3,
|
value_head_mlp_depth = 3,
|
||||||
policy_head_mlp_depth = 3,
|
policy_head_mlp_depth = 3,
|
||||||
latent_flow_loss_weight = 1.,
|
latent_flow_loss_weight = 1.,
|
||||||
discrete_action_loss_weight = 1.,
|
reward_loss_weight: float | list[float] = 1.,
|
||||||
continuous_action_loss_weight = 1.,
|
discrete_action_loss_weight: float | list[float] = 1.,
|
||||||
|
continuous_action_loss_weight: float | list[float] = 1.,
|
||||||
num_latent_genes = 0, # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
|
num_latent_genes = 0, # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
|
||||||
num_residual_streams = 1,
|
num_residual_streams = 1,
|
||||||
gae_discount_factor = 0.997,
|
gae_discount_factor = 0.997,
|
||||||
@ -1743,9 +1743,6 @@ class DynamicsWorldModel(Module):
|
|||||||
squeeze_unembed_preds = False
|
squeeze_unembed_preds = False
|
||||||
)
|
)
|
||||||
|
|
||||||
self.discrete_action_loss_weight = discrete_action_loss_weight
|
|
||||||
self.continuous_action_loss_weight = continuous_action_loss_weight
|
|
||||||
|
|
||||||
# multi token prediction length
|
# multi token prediction length
|
||||||
|
|
||||||
self.multi_token_pred_len = multi_token_pred_len
|
self.multi_token_pred_len = multi_token_pred_len
|
||||||
@ -1812,7 +1809,14 @@ class DynamicsWorldModel(Module):
|
|||||||
self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_discrete_actions > 0 else None
|
self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_discrete_actions > 0 else None
|
||||||
|
|
||||||
self.latent_flow_loss_weight = latent_flow_loss_weight
|
self.latent_flow_loss_weight = latent_flow_loss_weight
|
||||||
self.reward_loss_weight = reward_loss_weight
|
|
||||||
|
self.register_buffer('reward_loss_weight', tensor(reward_loss_weight))
|
||||||
|
self.register_buffer('discrete_action_loss_weight', tensor(discrete_action_loss_weight))
|
||||||
|
self.register_buffer('continuous_action_loss_weight', tensor(continuous_action_loss_weight))
|
||||||
|
|
||||||
|
assert self.reward_loss_weight.numel() in {1, multi_token_pred_len}
|
||||||
|
assert self.discrete_action_loss_weight.numel() in {1, multi_token_pred_len}
|
||||||
|
assert self.continuous_action_loss_weight.numel() in {1, multi_token_pred_len}
|
||||||
|
|
||||||
self.register_buffer('zero', tensor(0.), persistent = False)
|
self.register_buffer('zero', tensor(0.), persistent = False)
|
||||||
|
|
||||||
@ -2583,9 +2587,9 @@ class DynamicsWorldModel(Module):
|
|||||||
|
|
||||||
total_loss = (
|
total_loss = (
|
||||||
flow_loss * self.latent_flow_loss_weight +
|
flow_loss * self.latent_flow_loss_weight +
|
||||||
reward_loss.sum() * self.reward_loss_weight +
|
(reward_loss * self.reward_loss_weight).sum() +
|
||||||
discrete_action_loss.sum() * self.discrete_action_loss_weight +
|
(discrete_action_loss * self.discrete_action_loss_weight).sum() +
|
||||||
continuous_action_loss.sum() * self.continuous_action_loss_weight
|
(continuous_action_loss * self.continuous_action_loss_weight).sum()
|
||||||
)
|
)
|
||||||
|
|
||||||
if not return_all_losses:
|
if not return_all_losses:
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "dreamer4"
|
name = "dreamer4"
|
||||||
version = "0.0.42"
|
version = "0.0.43"
|
||||||
description = "Dreamer 4"
|
description = "Dreamer 4"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user