lucidrains 2025-10-28 18:55:22 -07:00
parent 7acaa764f6
commit 91d697f8ca
3 changed files with 52 additions and 22 deletions

View File

@@ -186,6 +186,15 @@ def lens_to_mask(t, max_len = None):
return einx.less('j, i -> i j', seq, t)
def masked_mean(t, mask = None):
if not exists(mask):
return t.mean()
if not mask.any():
return t[mask].sum()
return t[mask].mean()
def log(t, eps = 1e-20):
return t.clamp(min = eps).log()
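The masked_mean helper added above has three cases: with no mask it falls back to a plain mean, with an all-False mask it returns a zero through an empty sum (avoiding the NaN a mean over an empty selection would produce), and otherwise it averages only the selected elements. A minimal standalone sketch mirroring that logic, with made-up example tensors:

import torch

def masked_mean_sketch(t, mask = None):
    # no mask: average over everything
    if mask is None:
        return t.mean()

    # all-False mask: an empty .sum() yields 0. instead of the NaN that .mean() would give
    if not mask.any():
        return t[mask].sum()

    # otherwise: mean over the selected elements only
    return t[mask].mean()

t = torch.tensor([1., 2., 3., 4.])

assert masked_mean_sketch(t) == 2.5
assert masked_mean_sketch(t, torch.tensor([True, True, False, False])) == 1.5
assert masked_mean_sketch(t, torch.zeros(4, dtype = torch.bool)) == 0.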
@@ -1824,6 +1833,7 @@ class DynamicsWorldModel(Module):
gae_discount_factor = 0.997,
gae_lambda = 0.95,
ppo_eps_clip = 0.2,
pmpo_pos_to_neg_weight = 0.5, # positive and negative advantages weighted equally by default
value_clip = 0.4,
policy_entropy_weight = .01,
gae_use_accelerated = False
@@ -2027,6 +2037,10 @@ class DynamicsWorldModel(Module):
self.value_clip = value_clip
self.policy_entropy_weight = policy_entropy_weight
# pmpo related
self.pmpo_pos_to_neg_weight = pmpo_pos_to_neg_weight
# rewards related
self.keep_reward_ema_stats = keep_reward_ema_stats
@@ -2334,7 +2348,7 @@ class DynamicsWorldModel(Module):
policy_optim: Optimizer | None = None,
value_optim: Optimizer | None = None,
only_learn_policy_value_heads = True, # in the paper, they do not finetune the entire dynamics model; they only learn the heads
use_signed_advantage = True,
use_pmpo = True,
eps = 1e-6
):
@@ -2374,6 +2388,8 @@ class DynamicsWorldModel(Module):
max_time = latents.shape[1]
is_var_len = exists(experience.lens)
mask = None
if is_var_len:
learnable_lens = experience.lens - experience.is_truncated.long() # if truncated, exclude the last timestep, as its value is bootstrapped
mask = lens_to_mask(learnable_lens, max_time)
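lens_to_mask (first hunk above) builds a boolean mask over time from per-sequence lengths via einx.less, and learnable_lens shaves off the final timestep of truncated episodes since its value is bootstrapped rather than learned. A rough plain-torch equivalent, with hypothetical lengths, just to make the shapes concrete:

import torch

lens = torch.tensor([3, 5])          # hypothetical per-episode lengths
is_truncated = torch.tensor([1, 0])  # hypothetical truncation flags
max_time = 5

# truncated episodes lose their last timestep, whose value is bootstrapped
learnable_lens = lens - is_truncated

# plain-torch equivalent of lens_to_mask: position j is valid when j < length i
mask = torch.arange(max_time) < learnable_lens[:, None]
# tensor([[ True,  True, False, False, False],
#         [ True,  True,  True,  True,  True]])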
@@ -2417,8 +2433,9 @@ class DynamicsWorldModel(Module):
# apparently they just use the sign of the advantage
# https://arxiv.org/abs/2410.04166v1
if use_signed_advantage:
advantage = advantage.sign()
if use_pmpo:
pos_advantage_mask = advantage >= 0.
neg_advantage_mask = ~pos_advantage_mask
else:
advantage = F.layer_norm(advantage, advantage.shape, eps = eps)
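To make the two branches concrete: with use_pmpo the magnitude of the advantage is discarded and only its sign is kept, as positive and negative masks, whereas the PPO path keeps the magnitudes but whitens them with a layer norm over the whole tensor. A quick sketch on a hypothetical advantage tensor:

import torch
import torch.nn.functional as F

advantage = torch.tensor([[0.3, -1.2, 2.5, -0.1]])  # hypothetical (batch, time) advantages
eps = 1e-6

# pmpo: keep only the sign, as positive / negative step masks
pos_advantage_mask = advantage >= 0.
neg_advantage_mask = ~pos_advantage_mask

# ppo: keep the magnitudes, normalized to zero mean and unit variance across the whole tensor
normed_advantage = F.layer_norm(advantage, advantage.shape, eps = eps)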
@@ -2464,35 +2481,48 @@ class DynamicsWorldModel(Module):
log_probs = safe_cat(log_probs, dim = -1)
entropies = safe_cat(entropies, dim = -1)
ratio = (log_probs - old_log_probs).exp()
clipped_ratio = ratio.clamp(1. - self.ppo_eps_clip, 1. + self.ppo_eps_clip)
advantage = rearrange(advantage, '... -> ... 1') # broadcast across all actions
# clipped surrogate loss
if use_pmpo:
# pmpo - weights the positive and negative advantages equally, ignoring the magnitude of the advantage and using only its sign
# appears to be averaged across batch and time, if i understand correctly
# eq (10) in https://arxiv.org/html/2410.04166v1
policy_loss = -torch.min(ratio * advantage, clipped_ratio * advantage)
policy_loss = reduce(policy_loss, 'b t na -> b t', 'sum')
if exists(mask):
pos_advantage_mask &= mask
neg_advantage_mask &= mask
policy_loss = policy_loss.mean()
α = self.pmpo_pos_to_neg_weight
pos = masked_mean(log_probs, pos_advantage_mask)
neg = -masked_mean(log_probs, neg_advantage_mask)
policy_loss = -(α * pos + (1. - α) * neg)
else:
# ppo clipped surrogate loss
ratio = (log_probs - old_log_probs).exp()
clipped_ratio = ratio.clamp(1. - self.ppo_eps_clip, 1. + self.ppo_eps_clip)
policy_loss = -torch.min(ratio * advantage, clipped_ratio * advantage)
policy_loss = reduce(policy_loss, 'b t na -> b t', 'sum')
policy_loss = masked_mean(policy_loss, mask)
# handle entropy loss for naive exploration bonus
entropy_loss = - reduce(entropies, 'b t na -> b t', 'sum')
entropy_loss = masked_mean(entropy_loss, mask)
# total policy loss
total_policy_loss = (
policy_loss +
entropy_loss * self.policy_entropy_weight
)
# maybe handle variable lengths
if is_var_len:
total_policy_loss = total_policy_loss[mask].mean()
else:
total_policy_loss = total_policy_loss.mean()
# maybe take policy optimizer step
if exists(policy_optim):
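Putting the PMPO branch together: instead of the PPO clipped surrogate, the loss raises the log-probability of actions taken at steps with non-negative advantage and lowers it at steps with negative advantage, with pmpo_pos_to_neg_weight (α) balancing the two groups across batch and time, per eq (10) of the referenced paper. A self-contained sketch of that objective under assumed (batch, time, num_actions) shapes, reusing the masked_mean logic from the first hunk; the names and example tensors here are illustrative, not the repo's API:

import torch

def masked_mean(t, mask = None):
    # mirrors the helper added in this commit
    if mask is None:
        return t.mean()
    if not mask.any():
        return t[mask].sum()
    return t[mask].mean()

def pmpo_policy_loss(log_probs, advantage, mask = None, alpha = 0.5):
    # log_probs: (batch, time, num_actions) log-probs of the sampled actions, advantage: (batch, time)
    pos_advantage_mask = advantage >= 0.
    neg_advantage_mask = ~pos_advantage_mask

    # restrict to valid timesteps for variable length episodes
    if mask is not None:
        pos_advantage_mask &= mask
        neg_advantage_mask &= mask

    pos = masked_mean(log_probs, pos_advantage_mask)    # encourage actions taken when advantage was non-negative
    neg = -masked_mean(log_probs, neg_advantage_mask)   # discourage actions taken when advantage was negative

    return -(alpha * pos + (1. - alpha) * neg)

log_probs = -torch.rand(2, 5, 3)   # hypothetical log-probs for 3 action components
advantage = torch.randn(2, 5)
loss = pmpo_policy_loss(log_probs, advantage)

With alpha at the default 0.5 (pmpo_pos_to_neg_weight), the positive and negative groups contribute equally to the loss no matter how many steps fall into each, which is the equal weighting the comment above refers to.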

View File

@@ -1,6 +1,6 @@ [project]
[project]
name = "dreamer4"
version = "0.0.90"
version = "0.0.91"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

View File

@@ -611,13 +611,13 @@ def test_cache_generate():
generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
@param('vectorized', (False, True))
@param('use_signed_advantage', (False, True))
@param('use_pmpo', (False, True))
@param('env_can_terminate', (False, True))
@param('env_can_truncate', (False, True))
@param('store_agent_embed', (False, True))
def test_online_rl(
vectorized,
use_signed_advantage,
use_pmpo,
env_can_terminate,
env_can_truncate,
store_agent_embed
@@ -674,7 +674,7 @@ def test_online_rl(
if store_agent_embed:
assert exists(combined_experience.agent_embed)
actor_loss, critic_loss = world_model_and_policy.learn_from_experience(combined_experience, use_signed_advantage = use_signed_advantage)
actor_loss, critic_loss = world_model_and_policy.learn_from_experience(combined_experience, use_pmpo = use_pmpo)
actor_loss.backward()
critic_loss.backward()