diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 880bdb6..83e8658 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -66,7 +66,7 @@ LinearNoBias = partial(Linear, bias = False)
 
 TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips'))
 
-WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'reward', 'behavior_clone'))
+WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions'))
 
 @dataclass
 class Experience:
@@ -210,6 +210,48 @@ def create_multi_token_prediction_targets(
 
 # loss related
 
+class LossNormalizer(Module):
+
+    # the authors mentioned the need for loss normalization in the dynamics transformer
+
+    def __init__(
+        self,
+        num_losses: int,
+        beta = 0.95,
+        eps = 1e-6
+    ):
+        super().__init__()
+        self.register_buffer('exp_avg_sq', torch.ones(num_losses))
+        self.beta = beta
+        self.eps = eps
+
+    def forward(
+        self,
+        losses: Tensor,
+        update_ema = None
+    ):
+        exp_avg_sq, beta = self.exp_avg_sq, self.beta
+        update_ema = default(update_ema, self.training)
+
+        # get the rms value - as mentioned at the end of section 3 in the paper
+
+        rms = exp_avg_sq.sqrt()
+
+        if update_ema:
+            decay = 1. - beta
+
+            # update the ema of the squared losses
+
+            exp_avg_sq.lerp_(losses.detach().square(), decay)
+
+        # then normalize by the running rms
+
+        assert losses.numel() == rms.numel()
+
+        normed_losses = losses / rms.clamp(min = self.eps)
+
+        return normed_losses
+
 class LPIPSLoss(Module):
     def __init__(
         self,
@@ -1570,11 +1612,13 @@ class DynamicsWorldModel(Module):
         num_discrete_actions: int | tuple[int, ...] = 0,
         num_continuous_actions = 0,
         continuous_norm_stats = None,
-        reward_loss_weight = 0.1,
+        reward_loss_weight = 1.,
         multi_token_pred_len = 8,
         value_head_mlp_depth = 3,
         policy_head_mlp_depth = 3,
-        behavior_clone_weight = 0.1,
+        latent_flow_loss_weight = 1.,
+        discrete_action_loss_weight = 1.,
+        continuous_action_loss_weight = 1.,
         num_latent_genes = 0, # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
         num_residual_streams = 1,
         gae_discount_factor = 0.997,
@@ -1699,7 +1743,8 @@ class DynamicsWorldModel(Module):
             squeeze_unembed_preds = False
         )
 
-        self.behavior_clone_weight = behavior_clone_weight
+        self.discrete_action_loss_weight = discrete_action_loss_weight
+        self.continuous_action_loss_weight = continuous_action_loss_weight
 
         # multi token prediction length
 
@@ -1726,8 +1771,6 @@ class DynamicsWorldModel(Module):
             multi_token_pred_len
         )
 
-        self.reward_loss_weight = reward_loss_weight
-
         # value head
 
         self.value_head = create_mlp(
@@ -1761,7 +1804,15 @@ class DynamicsWorldModel(Module):
         self.value_clip = value_clip
         self.policy_entropy_weight = value_clip
 
-        # zero
+        # loss related
+
+        self.flow_loss_normalizer = LossNormalizer(1)
+        self.reward_loss_normalizer = LossNormalizer(multi_token_pred_len)
+        self.discrete_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_discrete_actions > 0 else None
+        self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_continuous_actions > 0 else None
+
+        self.latent_flow_loss_weight = latent_flow_loss_weight
+        self.reward_loss_weight = reward_loss_weight
 
         self.register_buffer('zero', tensor(0.), persistent = False)
 
@@ -2461,11 +2512,12 @@ class DynamicsWorldModel(Module):
 
             reward_losses = reward_losses.masked_fill(reward_loss_mask, 0.)
 
-            reward_loss = reward_losses.sum(dim = -1).mean() # they sum across the prediction steps - eq(9)
+            reward_loss = reduce(reward_losses, '... mtp -> mtp', 'mean') # keep per prediction step (mtp) losses - summed across steps later, per eq (9)
 
         # maybe autoregressive action loss
 
-        behavior_clone_loss = self.zero
+        discrete_action_loss = self.zero
+        continuous_action_loss = self.zero
 
         if (
             self.num_agents == 1 and
@@ -2503,26 +2555,41 @@ class DynamicsWorldModel(Module):
             if exists(discrete_log_probs):
                 discrete_log_probs = discrete_log_probs.masked_fill(~discrete_mask[..., None], 0.)
 
-                behavior_clone_loss = behavior_clone_loss - reduce(discrete_log_probs, 'mtp b t na -> b t', 'sum').mean()
+                discrete_action_loss = reduce(-discrete_log_probs, 'mtp b t na -> mtp', 'mean')
 
             if exists(continuous_log_probs):
                 continuous_log_probs = continuous_log_probs.masked_fill(~continuous_mask[..., None], 0.)
 
-                behavior_clone_loss = behavior_clone_loss - reduce(continuous_log_probs, 'mtp b t na -> b t', 'sum').mean()
+                continuous_action_loss = reduce(-continuous_log_probs, 'mtp b t na -> mtp', 'mean')
 
-        # gather losses
+        # handle loss normalization
+
+        losses = WorldModelLosses(flow_loss, reward_loss, discrete_action_loss, continuous_action_loss)
+
+        if exists(self.flow_loss_normalizer):
+            flow_loss = self.flow_loss_normalizer(flow_loss)
+
+        if exists(rewards) and exists(self.reward_loss_normalizer):
+            reward_loss = self.reward_loss_normalizer(reward_loss)
+
+        if exists(discrete_actions) and exists(self.discrete_actions_loss_normalizer):
+            discrete_action_loss = self.discrete_actions_loss_normalizer(discrete_action_loss)
+
+        if exists(continuous_actions) and exists(self.continuous_actions_loss_normalizer):
+            continuous_action_loss = self.continuous_actions_loss_normalizer(continuous_action_loss)
+
+        # gather losses - they sum across the multi token prediction steps for rewards and actions - eq (9)
 
         total_loss = (
-            flow_loss +
-            reward_loss * self.reward_loss_weight +
-            behavior_clone_loss * self.behavior_clone_weight
+            flow_loss * self.latent_flow_loss_weight +
+            reward_loss.sum() * self.reward_loss_weight +
+            discrete_action_loss.sum() * self.discrete_action_loss_weight +
+            continuous_action_loss.sum() * self.continuous_action_loss_weight
         )
 
         if not return_all_losses:
             return total_loss
 
-        losses = WorldModelLosses(flow_loss, reward_loss, behavior_clone_loss)
-
         return total_loss, losses
 
 # dreamer
diff --git a/pyproject.toml b/pyproject.toml
index 0b9f2af..0d9f1d7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.40"
+version = "0.0.41"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index e1c996c..b13f49f 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -423,3 +423,15 @@ def test_mtp():
     discrete_logits, continuous_logits = embedder.unembed(action_embed, pred_head_index = 0)
 
     assert discrete_logits.shape == (2, 16, 8)
+
+def test_loss_normalizer():
+    from dreamer4.dreamer4 import LossNormalizer
+
+    loss_normalizer = LossNormalizer(4, beta = 0.)
+
+    losses = torch.rand(4)
+
+    _ = loss_normalizer(losses)
+    normed_losses = loss_normalizer(losses)
+
+    assert (normed_losses == 1.).all()
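Note (appended commentary, not part of the patch): the sketch below is a minimal, self-contained illustration of the normalization scheme the new LossNormalizer implements - each loss is divided by the square root of an exponential moving average of its squared value, so every term enters the weighted sum at roughly unit scale. The class name RMSLossNormalizer and the toy loss magnitudes are made up for the example.

import torch
from torch import nn

class RMSLossNormalizer(nn.Module):
    # same idea as the patch's LossNormalizer: track an ema of each squared loss, divide by its root
    def __init__(self, num_losses, beta = 0.95, eps = 1e-6):
        super().__init__()
        self.register_buffer('exp_avg_sq', torch.ones(num_losses))
        self.beta = beta
        self.eps = eps

    def forward(self, losses):
        rms = self.exp_avg_sq.sqrt()

        if self.training:
            # ema update of the squared losses
            self.exp_avg_sq.lerp_(losses.detach().square(), 1. - self.beta)

        return losses / rms.clamp(min = self.eps)

normalizer = RMSLossNormalizer(3)

for _ in range(200):
    # three toy losses on very different scales
    raw_losses = torch.tensor([500., 0.02, 3.]) * (1. + 0.1 * torch.randn(3)).abs()
    normed = normalizer(raw_losses)

print(normed) # each entry hovers near 1 once the ema has warmed up

With every normalized loss sitting near unit RMS, fixed weights stay comparable across terms, which lines up with the patch replacing the hand-tuned 0.1 reward and behavior-clone weights with uniform defaults of 1.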