take care of storing the old action logits and mean log var, and calculate kl div for pmpo based off that during learn from experience

This commit is contained in:
lucidrains 2025-10-29 10:31:32 -07:00
parent 691d9ca007
commit 3547344312
2 changed files with 44 additions and 2 deletions

View File

@ -81,6 +81,7 @@ class Experience:
rewards: Tensor | None = None
actions: tuple[Tensor, Tensor] | None = None
log_probs: tuple[Tensor, Tensor] | None = None
old_action_unembeds: tuple[Tensor, Tensor] | None = None
values: Tensor | None = None
step_size: int | None = None
lens: Tensor | None = None
@ -1887,6 +1888,7 @@ class DynamicsWorldModel(Module):
gae_lambda = 0.95,
ppo_eps_clip = 0.2,
pmpo_pos_to_neg_weight = 0.5, # pos and neg equal weight
pmpo_kl_div_loss_weight = 1.,
value_clip = 0.4,
policy_entropy_weight = .01,
gae_use_accelerated = False
@ -2093,6 +2095,7 @@ class DynamicsWorldModel(Module):
# pmpo related
self.pmpo_pos_to_neg_weight = pmpo_pos_to_neg_weight
self.pmpo_kl_div_loss_weight = pmpo_kl_div_loss_weight
# rewards related
@ -2222,7 +2225,8 @@ class DynamicsWorldModel(Module):
max_timesteps = 16,
env_is_vectorized = False,
use_time_kv_cache = True,
store_agent_embed = False
store_agent_embed = False,
store_old_action_unembeds = False,
):
assert exists(self.video_tokenizer)
@ -2248,6 +2252,7 @@ class DynamicsWorldModel(Module):
latents = None
acc_agent_embed = None
acc_policy_embed = None
# keep track of termination, for setting the `is_truncated` field on Experience and for early stopping interaction with env
@ -2300,6 +2305,9 @@ class DynamicsWorldModel(Module):
policy_embed = self.policy_head(one_agent_embed)
if store_old_action_unembeds:
acc_policy_embed = safe_cat((acc_policy_embed, policy_embed), dim = 1)
# sample actions
sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed, pred_head_index = 0, squeeze = True)
@ -2383,6 +2391,7 @@ class DynamicsWorldModel(Module):
actions = (discrete_actions, continuous_actions),
log_probs = (discrete_log_probs, continuous_log_probs),
values = values,
old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if store_old_action_unembeds else None,
agent_embed = acc_agent_embed if store_agent_embed else None,
step_size = step_size,
agent_index = agent_index,
@ -2411,6 +2420,7 @@ class DynamicsWorldModel(Module):
old_values = experience.values
rewards = experience.rewards
agent_embeds = experience.agent_embed
old_action_unembeds = experience.old_action_unembeds
step_size = experience.step_size
agent_index = experience.agent_index
@ -2489,6 +2499,7 @@ class DynamicsWorldModel(Module):
if use_pmpo:
pos_advantage_mask = advantage >= 0.
neg_advantage_mask = ~pos_advantage_mask
else:
advantage = F.layer_norm(advantage, advantage.shape, eps = eps)
@ -2552,6 +2563,25 @@ class DynamicsWorldModel(Module):
policy_loss = -(α * pos + (1. - α) * neg)
# take care of kl
if self.pmpo_kl_div_loss_weight > 0.:
new_unembedded_actions = self.action_embedder.unembed(policy_embed, pred_head_index = 0)
discrete_kl_div, continuous_kl_div = self.action_embedder.kl_div(new_unembedded_actions, old_action_unembeds)
# accumulate discrete and continuous kl div
kl_div_loss = 0.
if exists(discrete_kl_div):
kl_div_loss = kl_div_loss + discrete_kl_div[mask].mean()
if exists(continuous_kl_div):
kl_div_loss = kl_div_loss + continuous_kl_div[mask].mean()
policy_loss = policy_loss + kl_div_loss * self.pmpo_kl_div_loss_weight
else:
# ppo clipped surrogate loss
@ -2694,6 +2724,10 @@ class DynamicsWorldModel(Module):
acc_agent_embed = None
# maybe store old actions for kl
acc_policy_embed = None
# maybe return rewards
decoded_rewards = None
@ -2818,6 +2852,13 @@ class DynamicsWorldModel(Module):
policy_embed = self.policy_head(one_agent_embed)
# maybe store old actions
if store_old_action_unembeds:
acc_policy_embed = safe_cat((acc_policy_embed, policy_embed))
# sample actions
sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed, pred_head_index = 0, squeeze = True)
decoded_discrete_actions = safe_cat((decoded_discrete_actions, sampled_discrete_actions), dim = 1)
@ -2906,6 +2947,7 @@ class DynamicsWorldModel(Module):
video = video,
proprio = proprio if has_proprio else None,
agent_embed = acc_agent_embed if store_agent_embed else None,
old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if store_old_action_unembeds else None,
step_size = step_size,
agent_index = agent_index,
lens = experience_lens,

View File

@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.92"
version = "0.0.93"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }