From 91d697f8ca270edf1a309300d58abc7f0160b363 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Tue, 28 Oct 2025 18:55:22 -0700
Subject: [PATCH] fix pmpo

---
 dreamer4/dreamer4.py  | 66 +++++++++++++++++++++++++++++++------------
 pyproject.toml        |  2 +-
 tests/test_dreamer.py |  6 ++--
 3 files changed, 52 insertions(+), 22 deletions(-)

diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index e3d16da..228edc1 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -186,6 +186,15 @@ def lens_to_mask(t, max_len = None):
 
     return einx.less('j, i -> i j', seq, t)
 
+def masked_mean(t, mask = None):
+    if not exists(mask):
+        return t.mean()
+
+    if not mask.any():
+        return t[mask].sum()
+
+    return t[mask].mean()
+
 def log(t, eps = 1e-20):
     return t.clamp(min = eps).log()
 
@@ -1824,6 +1833,7 @@ class DynamicsWorldModel(Module):
         gae_discount_factor = 0.997,
         gae_lambda = 0.95,
         ppo_eps_clip = 0.2,
+        pmpo_pos_to_neg_weight = 0.5, # pos and neg equal weight
         value_clip = 0.4,
         policy_entropy_weight = .01,
         gae_use_accelerated = False
@@ -2027,6 +2037,10 @@ class DynamicsWorldModel(Module):
         self.value_clip = value_clip
         self.policy_entropy_weight = value_clip
 
+        # pmpo related
+
+        self.pmpo_pos_to_neg_weight = pmpo_pos_to_neg_weight
+
         # rewards related
 
         self.keep_reward_ema_stats = keep_reward_ema_stats
@@ -2334,7 +2348,7 @@ class DynamicsWorldModel(Module):
         policy_optim: Optimizer | None = None,
         value_optim: Optimizer | None = None,
         only_learn_policy_value_heads = True, # in the paper, they do not finetune the entire dynamics model, they just learn the heads
-        use_signed_advantage = True,
+        use_pmpo = True,
         eps = 1e-6
     ):
 
@@ -2374,6 +2388,8 @@ class DynamicsWorldModel(Module):
         max_time = latents.shape[1]
         is_var_len = exists(experience.lens)
 
+        mask = None
+
         if is_var_len:
             learnable_lens = experience.lens - experience.is_truncated.long() # if is truncated, remove the last one, as it is bootstrapped value
             mask = lens_to_mask(learnable_lens, max_time)
@@ -2417,8 +2433,9 @@ class DynamicsWorldModel(Module):
         # apparently they just use the sign of the advantage
         # https://arxiv.org/abs/2410.04166v1
 
-        if use_signed_advantage:
-            advantage = advantage.sign()
+        if use_pmpo:
+            pos_advantage_mask = advantage >= 0.
+            neg_advantage_mask = ~pos_advantage_mask
         else:
             advantage = F.layer_norm(advantage, advantage.shape, eps = eps)
 
@@ -2464,35 +2481,48 @@ class DynamicsWorldModel(Module):
         log_probs = safe_cat(log_probs, dim = -1)
         entropies = safe_cat(entropies, dim = -1)
 
-        ratio = (log_probs - old_log_probs).exp()
-
-        clipped_ratio = ratio.clamp(1. - self.ppo_eps_clip, 1. + self.ppo_eps_clip)
-
         advantage = rearrange(advantage, '... -> ... 1') # broadcast across all actions
 
-        # clipped surrogate loss
+        if use_pmpo:
+            # pmpo - weighting the positive and negative advantages equally - ignoring magnitude of advantage and taking the sign
+            # seems to be weighted across batch and time, iiuc
+            # eq (10) in https://arxiv.org/html/2410.04166v1
 
-        policy_loss = -torch.min(ratio * advantage, clipped_ratio * advantage)
-        policy_loss = reduce(policy_loss, 'b t na -> b t', 'sum')
+            if exists(mask):
+                pos_advantage_mask &= mask
+                neg_advantage_mask &= mask
 
-        policy_loss = policy_loss.mean()
+            α = self.pmpo_pos_to_neg_weight
+
+            pos = masked_mean(log_probs, pos_advantage_mask)
+            neg = -masked_mean(log_probs, neg_advantage_mask)
+
+            policy_loss = -(α * pos + (1. - α) * neg)
+
+        else:
+            # ppo clipped surrogate loss
+
+            ratio = (log_probs - old_log_probs).exp()
+            clipped_ratio = ratio.clamp(1. - self.ppo_eps_clip, 1. + self.ppo_eps_clip)
+
+            policy_loss = -torch.min(ratio * advantage, clipped_ratio * advantage)
+            policy_loss = reduce(policy_loss, 'b t na -> b t', 'sum')
+
+        policy_loss = masked_mean(policy_loss, mask)
 
         # handle entropy loss for naive exploration bonus
 
         entropy_loss = - reduce(entropies, 'b t na -> b t', 'sum')
 
+        entropy_loss = masked_mean(entropy_loss, mask)
+
+        # total policy loss
+
         total_policy_loss = (
             policy_loss +
             entropy_loss * self.policy_entropy_weight
         )
 
-        # maybe handle variable lengths
-
-        if is_var_len:
-            total_policy_loss = total_policy_loss[mask].mean()
-        else:
-            total_policy_loss = total_policy_loss.mean()
-
         # maybe take policy optimizer step
 
         if exists(policy_optim):
diff --git a/pyproject.toml b/pyproject.toml
index a8d2968..48dfec2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.90"
+version = "0.0.91"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index 9fb9d12..85bd78b 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -611,13 +611,13 @@ def test_cache_generate():
     generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
 
 @param('vectorized', (False, True))
-@param('use_signed_advantage', (False, True))
+@param('use_pmpo', (False, True))
 @param('env_can_terminate', (False, True))
 @param('env_can_truncate', (False, True))
 @param('store_agent_embed', (False, True))
 def test_online_rl(
     vectorized,
-    use_signed_advantage,
+    use_pmpo,
     env_can_terminate,
     env_can_truncate,
     store_agent_embed
@@ -674,7 +674,7 @@ def test_online_rl(
     if store_agent_embed:
         assert exists(combined_experience.agent_embed)
 
-    actor_loss, critic_loss = world_model_and_policy.learn_from_experience(combined_experience, use_signed_advantage = use_signed_advantage)
+    actor_loss, critic_loss = world_model_and_policy.learn_from_experience(combined_experience, use_pmpo = use_pmpo)
 
     actor_loss.backward()
     critic_loss.backward()
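
Note: as a standalone illustration of the PMPO objective this patch switches to (eq (10) of https://arxiv.org/abs/2410.04166), the sketch below re-implements the sign-split masked means and the alpha weighting outside the repo. The names pmpo_policy_loss and the toy tensors are illustrative only (not part of the library), the per-action dimension is collapsed for brevity, and plain PyTorch is assumed.

import torch

def masked_mean(t, mask = None):
    # mean over masked-in elements only; an all-False mask falls back to a zero that still carries grad
    if mask is None:
        return t.mean()

    if not mask.any():
        return t[mask].sum()

    return t[mask].mean()

def pmpo_policy_loss(log_probs, advantage, mask = None, pos_to_neg_weight = 0.5):
    # only the sign of the advantage is used - maximize log prob where advantage >= 0,
    # minimize it where advantage < 0, each group averaged over batch and time
    pos_mask = advantage >= 0.
    neg_mask = ~pos_mask

    if mask is not None:
        # drop padded timesteps from both groups
        pos_mask &= mask
        neg_mask &= mask

    alpha = pos_to_neg_weight
    pos = masked_mean(log_probs, pos_mask)
    neg = -masked_mean(log_probs, neg_mask)

    return -(alpha * pos + (1. - alpha) * neg)

# toy usage - batch of 2 trajectories, 4 timesteps, variable lengths
log_probs = torch.randn(2, 4, requires_grad = True)
advantage = torch.randn(2, 4)
mask = torch.tensor([
    [True, True, True, False],
    [True, True, False, False]
])

loss = pmpo_policy_loss(log_probs, advantage, mask)
loss.backward()
print(loss.item(), log_probs.grad.shape)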