take care of the loss normalization mentioned at the end of the first paragraph of section 3

lucidrains 2025-10-19 08:24:41 -07:00
parent 79a1b1c46e
commit 374667d8a9
3 changed files with 98 additions and 18 deletions

dreamer4/dreamer4.py

@@ -66,7 +66,7 @@ LinearNoBias = partial(Linear, bias = False)

 TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))

-WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'reward', 'behavior_clone'))
+WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions'))

 @dataclass
 class Experience:
@@ -210,6 +210,48 @@ def create_multi_token_prediction_targets(

 # loss related

+class LossNormalizer(Module):
+    # the authors mentioned the need for loss normalization in the dynamics transformer
+
+    def __init__(
+        self,
+        num_losses: int,
+        beta = 0.95,
+        eps = 1e-6
+    ):
+        super().__init__()
+        self.register_buffer('exp_avg_sq', torch.ones(num_losses))
+        self.beta = beta
+        self.eps = eps
+
+    def forward(
+        self,
+        losses: Tensor | list[Tensor] | dict[str, Tensor],
+        update_ema = None
+    ):
+        exp_avg_sq, beta = self.exp_avg_sq, self.beta
+
+        update_ema = default(update_ema, self.training)
+
+        # get the rms value - as mentioned at the end of section 3 in the paper
+
+        rms = exp_avg_sq.sqrt()
+
+        if update_ema:
+            decay = 1. - beta
+
+            # update the ema
+
+            exp_avg_sq.lerp_(losses.detach().square(), decay)
+
+        # then normalize
+
+        assert losses.numel() == rms.numel()
+
+        normed_losses = losses / rms.clamp(min = self.eps)
+        return normed_losses
+
 class LPIPSLoss(Module):
     def __init__(
         self,
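As a quick illustration of what the class above does (a minimal usage sketch, not part of the commit; it only assumes LossNormalizer is importable from dreamer4.dreamer4, as the new test below does): each loss component is divided by the running RMS of its own recent magnitude, so components that live on very different scales end up roughly comparable before weighting.

import torch
from dreamer4.dreamer4 import LossNormalizer

normalizer = LossNormalizer(2)          # one EMA slot per loss component

for _ in range(200):
    # two losses on wildly different scales
    losses = torch.tensor([100., 0.01]) * (1. + 0.05 * torch.randn(2))
    normed = normalizer(losses)         # each entry divided by sqrt of the EMA of its squared value

print(normed)                           # both entries hover around ~1 once the EMA has warmed up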
@@ -1570,11 +1612,13 @@ class DynamicsWorldModel(Module):
         num_discrete_actions: int | tuple[int, ...] = 0,
         num_continuous_actions = 0,
         continuous_norm_stats = None,
-        reward_loss_weight = 0.1,
+        reward_loss_weight = 1.,
         multi_token_pred_len = 8,
         value_head_mlp_depth = 3,
         policy_head_mlp_depth = 3,
-        behavior_clone_weight = 0.1,
+        latent_flow_loss_weight = 1.,
+        discrete_action_loss_weight = 1.,
+        continuous_action_loss_weight = 1.,
         num_latent_genes = 0, # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
         num_residual_streams = 1,
         gae_discount_factor = 0.997,
@@ -1699,7 +1743,8 @@ class DynamicsWorldModel(Module):
             squeeze_unembed_preds = False
         )

-        self.behavior_clone_weight = behavior_clone_weight
+        self.discrete_action_loss_weight = discrete_action_loss_weight
+        self.continuous_action_loss_weight = continuous_action_loss_weight

         # multi token prediction length
@ -1726,8 +1771,6 @@ class DynamicsWorldModel(Module):
multi_token_pred_len multi_token_pred_len
) )
self.reward_loss_weight = reward_loss_weight
# value head # value head
self.value_head = create_mlp( self.value_head = create_mlp(
@@ -1761,7 +1804,15 @@ class DynamicsWorldModel(Module):
         self.value_clip = value_clip
         self.policy_entropy_weight = value_clip

-        # zero
+        # loss related
+
+        self.flow_loss_normalizer = LossNormalizer(1)
+        self.reward_loss_normalizer = LossNormalizer(multi_token_pred_len)
+        self.discrete_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_discrete_actions > 0 else None
+        self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_continuous_actions > 0 else None
+
+        self.latent_flow_loss_weight = latent_flow_loss_weight
+        self.reward_loss_weight = reward_loss_weight

         self.register_buffer('zero', tensor(0.), persistent = False)
@@ -2461,11 +2512,12 @@ class DynamicsWorldModel(Module):
             reward_losses = reward_losses.masked_fill(reward_loss_mask, 0.)

-            reward_loss = reward_losses.sum(dim = -1).mean() # they sum across the prediction steps - eq(9)
+            reward_loss = reduce(reward_losses, '... mtp -> mtp', 'mean') # they sum across the prediction steps (mtp dimension) - eq(9)

         # maybe autoregressive action loss

-        behavior_clone_loss = self.zero
+        discrete_action_loss = self.zero
+        continuous_action_loss = self.zero

         if (
             self.num_agents == 1 and
@@ -2503,26 +2555,41 @@ class DynamicsWorldModel(Module):
             if exists(discrete_log_probs):
                 discrete_log_probs = discrete_log_probs.masked_fill(~discrete_mask[..., None], 0.)

-                behavior_clone_loss = behavior_clone_loss - reduce(discrete_log_probs, 'mtp b t na -> b t', 'sum').mean()
+                discrete_action_loss = reduce(-discrete_log_probs, 'mtp b t na -> mtp', 'mean')

             if exists(continuous_log_probs):
                 continuous_log_probs = continuous_log_probs.masked_fill(~continuous_mask[..., None], 0.)

-                behavior_clone_loss = behavior_clone_loss - reduce(continuous_log_probs, 'mtp b t na -> b t', 'sum').mean()
+                continuous_action_loss = reduce(-continuous_log_probs, 'mtp b t na -> mtp', 'mean')

-        # gather losses
+        # handle loss normalization
+
+        losses = WorldModelLosses(flow_loss, reward_loss, discrete_action_loss, continuous_action_loss)
+
+        if exists(self.flow_loss_normalizer):
+            flow_loss = self.flow_loss_normalizer(flow_loss)
+
+        if exists(rewards) and exists(self.reward_loss_normalizer):
+            reward_loss = self.reward_loss_normalizer(reward_loss)
+
+        if exists(discrete_actions) and exists(self.discrete_actions_loss_normalizer):
+            discrete_action_loss = self.discrete_actions_loss_normalizer(discrete_action_loss)
+
+        if exists(continuous_actions) and exists(self.continuous_actions_loss_normalizer):
+            continuous_action_loss = self.continuous_actions_loss_normalizer(continuous_action_loss)
+
+        # gather losses - they sum across the multi token prediction steps for rewards and actions - eq (9)

         total_loss = (
-            flow_loss +
-            reward_loss * self.reward_loss_weight +
-            behavior_clone_loss * self.behavior_clone_weight
+            flow_loss * self.latent_flow_loss_weight +
+            reward_loss.sum() * self.reward_loss_weight +
+            discrete_action_loss.sum() * self.discrete_action_loss_weight +
+            continuous_action_loss.sum() * self.continuous_action_loss_weight
         )

         if not return_all_losses:
             return total_loss

-        losses = WorldModelLosses(flow_loss, reward_loss, behavior_clone_loss)
-
         return total_loss, losses

 # dreamer
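Taken together, the hunks above replace the single behavior-clone term with per-step discrete and continuous action terms, each RMS-normalized before weighting (the WorldModelLosses tuple returned for logging is gathered before normalization). In symbols (my notation, not from the commit), writing \hat{\mathcal{L}} = \mathcal{L} / \max\big(\sqrt{\mathrm{EMA}[\mathcal{L}^2]},\, \epsilon\big) for an RMS-normalized loss and K for multi_token_pred_len, the new objective is

\mathcal{L}_{\text{total}} = \lambda_{\text{flow}} \, \hat{\mathcal{L}}_{\text{flow}} \;+\; \lambda_{\text{reward}} \sum_{k=1}^{K} \hat{\mathcal{L}}_{\text{reward}}^{(k)} \;+\; \lambda_{\text{disc}} \sum_{k=1}^{K} \hat{\mathcal{L}}_{\text{disc}}^{(k)} \;+\; \lambda_{\text{cont}} \sum_{k=1}^{K} \hat{\mathcal{L}}_{\text{cont}}^{(k)}

with all four weights defaulting to 1 (the new *_loss_weight arguments), matching the sum over prediction steps referenced as eq. (9) in the code comments.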

pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.40"
+version = "0.0.41"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

tests

@@ -423,3 +423,16 @@ def test_mtp():

     discrete_logits, continuous_logits = embedder.unembed(action_embed, pred_head_index = 0)
     assert discrete_logits.shape == (2, 16, 8)
+
+def test_loss_normalizer():
+    from torch import cat
+    from dreamer4.dreamer4 import LossNormalizer
+
+    loss_normalizer = LossNormalizer(4, beta = 0.)
+
+    losses = torch.rand(4)
+
+    _ = loss_normalizer(losses)
+    normed_losses = loss_normalizer(losses)
+
+    assert (normed_losses == 1.).all()
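Why the assert holds: with beta = 0 the EMA decay is 1 - beta = 1, so the first call returns the losses divided by sqrt(1) = 1 and then overwrites exp_avg_sq with losses.square(); on the second call rms is each loss's own magnitude, so every entry (positive from torch.rand and well above the eps clamp) is divided by itself and the normalizer returns ones.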