fix pmpo

parent 7acaa764f6
commit 91d697f8ca
@@ -186,6 +186,15 @@ def lens_to_mask(t, max_len = None):
     return einx.less('j, i -> i j', seq, t)
 
+def masked_mean(t, mask = None):
+    if not exists(mask):
+        return t.mean()
+
+    if not mask.any():
+        return t[mask].sum()
+
+    return t[mask].mean()
+
 def log(t, eps = 1e-20):
     return t.clamp(min = eps).log()
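Aside (not part of the commit): a minimal, self-contained sketch of the new masked_mean helper, with the repo's exists stubbed in. Note how an all-False mask short-circuits to t[mask].sum(), a zero tensor that stays on the autograd graph instead of the nan a mean over zero elements would produce:

    import torch

    def exists(v):
        return v is not None

    def masked_mean(t, mask = None):
        if not exists(mask):
            return t.mean()

        if not mask.any():
            return t[mask].sum() # zero tensor, avoids nan from an empty mean

        return t[mask].mean()

    t = torch.tensor([1., 2., 3., 4.])

    assert masked_mean(t).item() == 2.5                                           # plain mean, no mask
    assert masked_mean(t, torch.tensor([True, True, False, False])).item() == 1.5 # mean over masked entries only
    assert masked_mean(t, torch.zeros(4, dtype = torch.bool)).item() == 0.        # empty mask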
@@ -1824,6 +1833,7 @@ class DynamicsWorldModel(Module):
         gae_discount_factor = 0.997,
         gae_lambda = 0.95,
         ppo_eps_clip = 0.2,
+        pmpo_pos_to_neg_weight = 0.5, # pos and neg equal weight
         value_clip = 0.4,
         policy_entropy_weight = .01,
         gae_use_accelerated = False
@@ -2027,6 +2037,10 @@ class DynamicsWorldModel(Module):
         self.value_clip = value_clip
         self.policy_entropy_weight = policy_entropy_weight
 
+        # pmpo related
+
+        self.pmpo_pos_to_neg_weight = pmpo_pos_to_neg_weight
+
         # rewards related
 
         self.keep_reward_ema_stats = keep_reward_ema_stats
@@ -2334,7 +2348,7 @@ class DynamicsWorldModel(Module):
         policy_optim: Optimizer | None = None,
         value_optim: Optimizer | None = None,
         only_learn_policy_value_heads = True, # in the paper, they do not finetune the entire dynamics model, they just learn the heads
-        use_signed_advantage = True,
+        use_pmpo = True,
         eps = 1e-6
     ):
@@ -2374,6 +2388,8 @@ class DynamicsWorldModel(Module):
         max_time = latents.shape[1]
         is_var_len = exists(experience.lens)
 
+        mask = None
+
         if is_var_len:
             learnable_lens = experience.lens - experience.is_truncated.long() # if is truncated, remove the last one, as it is bootstrapped value
             mask = lens_to_mask(learnable_lens, max_time)
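Aside (illustration only, a plain-torch rewrite of the einx one-liner from the first hunk): lens_to_mask turns per-episode lengths into a boolean (batch, time) mask, and subtracting is_truncated drops the final step of truncated episodes, since that step carries the bootstrapped value rather than a learnable target:

    import torch

    def lens_to_mask(t, max_len):
        seq = torch.arange(max_len, device = t.device)
        return seq[None, :] < t[:, None] # equivalent to einx.less('j, i -> i j', seq, t)

    lens = torch.tensor([4, 3])
    is_truncated = torch.tensor([True, False])

    learnable_lens = lens - is_truncated.long() # -> tensor([3, 3])

    print(lens_to_mask(learnable_lens, 5))
    # tensor([[ True,  True,  True, False, False],
    #         [ True,  True,  True, False, False]])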
@@ -2417,8 +2433,9 @@ class DynamicsWorldModel(Module):
         # apparently they just use the sign of the advantage
         # https://arxiv.org/abs/2410.04166v1
 
-        if use_signed_advantage:
-            advantage = advantage.sign()
+        if use_pmpo:
+            pos_advantage_mask = advantage >= 0.
+            neg_advantage_mask = ~pos_advantage_mask
         else:
             advantage = F.layer_norm(advantage, advantage.shape, eps = eps)
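Aside (illustration only): in the non-PMPO branch, F.layer_norm with normalized_shape set to the full advantage.shape standardizes the advantages jointly across batch and time, giving zero mean and unit variance over every element:

    import torch
    import torch.nn.functional as F

    advantage = torch.randn(2, 8) * 5. + 3. # arbitrary scale and shift

    normed = F.layer_norm(advantage, advantage.shape, eps = 1e-6)

    print(normed.mean()) # ~0.
    print(normed.std())  # ~1.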
@@ -2464,35 +2481,48 @@ class DynamicsWorldModel(Module):
         log_probs = safe_cat(log_probs, dim = -1)
         entropies = safe_cat(entropies, dim = -1)
 
-        ratio = (log_probs - old_log_probs).exp()
-
-        clipped_ratio = ratio.clamp(1. - self.ppo_eps_clip, 1. + self.ppo_eps_clip)
-
         advantage = rearrange(advantage, '... -> ... 1') # broadcast across all actions
 
-        # clipped surrogate loss
-
-        policy_loss = -torch.min(ratio * advantage, clipped_ratio * advantage)
-        policy_loss = reduce(policy_loss, 'b t na -> b t', 'sum')
-
-        policy_loss = policy_loss.mean()
+        if use_pmpo:
+            # pmpo - weighting the positive and negative advantages equally - ignoring magnitude of advantage and taking the sign
+            # seems to be weighted across batch and time, iiuc
+            # eq (10) in https://arxiv.org/html/2410.04166v1
+
+            if exists(mask):
+                pos_advantage_mask &= mask
+                neg_advantage_mask &= mask
+
+            α = self.pmpo_pos_to_neg_weight
+
+            pos = masked_mean(log_probs, pos_advantage_mask)
+            neg = -masked_mean(log_probs, neg_advantage_mask)
+
+            policy_loss = -(α * pos + (1. - α) * neg)
+
+        else:
+            # ppo clipped surrogate loss
+
+            ratio = (log_probs - old_log_probs).exp()
+            clipped_ratio = ratio.clamp(1. - self.ppo_eps_clip, 1. + self.ppo_eps_clip)
+
+            policy_loss = -torch.min(ratio * advantage, clipped_ratio * advantage)
+            policy_loss = reduce(policy_loss, 'b t na -> b t', 'sum')
+
+            policy_loss = masked_mean(policy_loss, mask)
 
         # handle entropy loss for naive exploration bonus
 
         entropy_loss = - reduce(entropies, 'b t na -> b t', 'sum')
+        entropy_loss = masked_mean(entropy_loss, mask)
 
         # total policy loss
 
         total_policy_loss = (
             policy_loss +
             entropy_loss * self.policy_entropy_weight
         )
 
-        # maybe handle variable lengths
-
-        if is_var_len:
-            total_policy_loss = total_policy_loss[mask].mean()
-        else:
-            total_policy_loss = total_policy_loss.mean()
-
         # maybe take policy optimizer step
 
         if exists(policy_optim):
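Aside (not part of the commit): a self-contained sketch of the PMPO objective as wired above, per eq (10) of https://arxiv.org/abs/2410.04166. Only the sign of the advantage matters: log-probs of positive-advantage actions are pushed up, negative ones pushed down, with α weighting the two groups independently of how many samples land in each. The function and tensor shapes here are simplified for illustration:

    import torch

    def masked_mean(t, mask):
        if not mask.any():
            return t[mask].sum()
        return t[mask].mean()

    def pmpo_policy_loss(log_probs, advantage, alpha = 0.5, mask = None):
        # group by the sign of the advantage, ignoring its magnitude
        pos = advantage >= 0.
        neg = ~pos

        if mask is not None: # also exclude padded / non-learnable timesteps
            pos &= mask
            neg &= mask

        # maximize log prob where advantage was positive, minimize it where negative
        return -(alpha * masked_mean(log_probs, pos) - (1. - alpha) * masked_mean(log_probs, neg))

    log_probs = torch.randn(2, 8).log_softmax(dim = -1).requires_grad_()
    advantage = torch.randn(2, 8)

    loss = pmpo_policy_loss(log_probs, advantage)
    loss.backward()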
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.90"
+version = "0.0.91"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -611,13 +611,13 @@ def test_cache_generate():
     generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
 
 @param('vectorized', (False, True))
-@param('use_signed_advantage', (False, True))
+@param('use_pmpo', (False, True))
 @param('env_can_terminate', (False, True))
 @param('env_can_truncate', (False, True))
 @param('store_agent_embed', (False, True))
 def test_online_rl(
     vectorized,
-    use_signed_advantage,
+    use_pmpo,
     env_can_terminate,
     env_can_truncate,
     store_agent_embed
@@ -674,7 +674,7 @@ def test_online_rl(
     if store_agent_embed:
         assert exists(combined_experience.agent_embed)
 
-    actor_loss, critic_loss = world_model_and_policy.learn_from_experience(combined_experience, use_signed_advantage = use_signed_advantage)
+    actor_loss, critic_loss = world_model_and_policy.learn_from_experience(combined_experience, use_pmpo = use_pmpo)
 
     actor_loss.backward()
     critic_loss.backward()