take care of storing the old action logits and mean / log variance, and calculate the kl divergence for pmpo based on that during learn from experience
parent 691d9ca007
commit 3547344312
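In short, the rollout now keeps the policy head's unembedded actions (discrete logits plus a continuous mean and log variance) on the collected Experience, and learn-from-experience penalizes the KL divergence of the current policy against those stored values when training with PMPO. Below is a minimal sketch of the two KL terms this presumably involves; the helper names, tensor shapes, and the (mean, log variance) parameterization are illustrative assumptions, not the actual `action_embedder.kl_div` implementation.

import torch

def discrete_kl(old_logits, new_logits):
    # KL(old || new) between categorical action distributions given raw logits
    old_log_probs = old_logits.log_softmax(dim = -1)
    new_log_probs = new_logits.log_softmax(dim = -1)
    return (old_log_probs.exp() * (old_log_probs - new_log_probs)).sum(dim = -1)

def gaussian_kl(old_mean, old_log_var, new_mean, new_log_var):
    # KL(old || new) between diagonal Gaussians parameterized by mean and log variance
    return 0.5 * (
        new_log_var - old_log_var
        + (old_log_var.exp() + (old_mean - new_mean) ** 2) / new_log_var.exp()
        - 1.
    ).sum(dim = -1)

# toy shapes: batch of 2, 16 timesteps, 8 discrete actions, 4 continuous action dims
old_logits, new_logits = torch.randn(2, 16, 8), torch.randn(2, 16, 8)
old_mean, new_mean = torch.randn(2, 16, 4), torch.randn(2, 16, 4)
old_log_var, new_log_var = torch.randn(2, 16, 4), torch.randn(2, 16, 4)

kl = discrete_kl(old_logits, new_logits) + gaussian_kl(old_mean, old_log_var, new_mean, new_log_var)
assert kl.shape == (2, 16) # one KL value per timestep, later masked and averaged into the policy loss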
@@ -81,6 +81,7 @@ class Experience:
    rewards: Tensor | None = None
    actions: tuple[Tensor, Tensor] | None = None
    log_probs: tuple[Tensor, Tensor] | None = None
    old_action_unembeds: tuple[Tensor, Tensor] | None = None
    values: Tensor | None = None
    step_size: int | None = None
    lens: Tensor | None = None
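A toy stand-in for how the new field is presumably laid out: a pair holding the stored discrete action logits and continuous distribution parameters, left as None when old action unembeds are not stored. Only the annotation mirrors the source; the class name and concrete shapes below are assumptions for illustration.

from dataclasses import dataclass
import torch
from torch import Tensor

@dataclass
class ExperienceSketch:
    # mirrors the new Experience field: (discrete logits, continuous params) or None
    old_action_unembeds: tuple[Tensor, Tensor] | None = None

exp = ExperienceSketch(old_action_unembeds = (torch.randn(2, 16, 8), torch.randn(2, 16, 8)))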
@@ -1887,6 +1888,7 @@ class DynamicsWorldModel(Module):
gae_lambda = 0.95,
ppo_eps_clip = 0.2,
pmpo_pos_to_neg_weight = 0.5, # pos and neg equal weight
pmpo_kl_div_loss_weight = 1.,
value_clip = 0.4,
policy_entropy_weight = .01,
gae_use_accelerated = False
@@ -2093,6 +2095,7 @@ class DynamicsWorldModel(Module):
# pmpo related

self.pmpo_pos_to_neg_weight = pmpo_pos_to_neg_weight
self.pmpo_kl_div_loss_weight = pmpo_kl_div_loss_weight

# rewards related

@@ -2222,7 +2225,8 @@ class DynamicsWorldModel(Module):
max_timesteps = 16,
env_is_vectorized = False,
use_time_kv_cache = True,
store_agent_embed = False
store_agent_embed = False,
store_old_action_unembeds = False,
):
assert exists(self.video_tokenizer)

@@ -2248,6 +2252,7 @@ class DynamicsWorldModel(Module):
latents = None

acc_agent_embed = None
acc_policy_embed = None

# keep track of termination, for setting the `is_truncated` field on Experience and for early stopping interaction with env

@@ -2300,6 +2305,9 @@ class DynamicsWorldModel(Module):

policy_embed = self.policy_head(one_agent_embed)

if store_old_action_unembeds:
    acc_policy_embed = safe_cat((acc_policy_embed, policy_embed), dim = 1)

# sample actions

sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed, pred_head_index = 0, squeeze = True)
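The accumulation above leans on `safe_cat`, which by the look of it is a None-tolerant concatenation so the running buffer can start out as None on the first timestep. A small sketch under that assumption; the shapes are invented for illustration.

import torch

def safe_cat(tensors, dim = -1):
    # assumed behavior: drop Nones, return None if nothing remains, otherwise torch.cat
    tensors = [t for t in tensors if t is not None]
    if len(tensors) == 0:
        return None
    return torch.cat(tensors, dim = dim)

acc_policy_embed = None

for _ in range(4): # pretend rollout over 4 timesteps
    policy_embed = torch.randn(2, 1, 32) # (batch, one timestep, embed dim) - shapes assumed
    acc_policy_embed = safe_cat((acc_policy_embed, policy_embed), dim = 1)

assert acc_policy_embed.shape == (2, 4, 32) # unembedded once at the end and stored on the Experience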
@@ -2383,6 +2391,7 @@ class DynamicsWorldModel(Module):
actions = (discrete_actions, continuous_actions),
log_probs = (discrete_log_probs, continuous_log_probs),
values = values,
old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if store_old_action_unembeds else None,
agent_embed = acc_agent_embed if store_agent_embed else None,
step_size = step_size,
agent_index = agent_index,
@@ -2411,6 +2420,7 @@ class DynamicsWorldModel(Module):
old_values = experience.values
rewards = experience.rewards
agent_embeds = experience.agent_embed
old_action_unembeds = experience.old_action_unembeds

step_size = experience.step_size
agent_index = experience.agent_index
@@ -2489,6 +2499,7 @@ class DynamicsWorldModel(Module):
if use_pmpo:
    pos_advantage_mask = advantage >= 0.
    neg_advantage_mask = ~pos_advantage_mask

else:
    advantage = F.layer_norm(advantage, advantage.shape, eps = eps)

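For the PMPO branch, one plausible reading of the masks and the α weighting that follows (`policy_loss = -(α * pos + (1. - α) * neg)`): raise the log-likelihood of actions whose advantage is non-negative and lower it for the rest, with `pmpo_pos_to_neg_weight` balancing the two masked means. The sketch below is that reading only, not the repository's exact `pos` / `neg` terms.

import torch

def pmpo_policy_loss(log_probs, advantage, alpha = 0.5):
    # split timesteps by the sign of the advantage
    pos_mask = advantage >= 0.
    neg_mask = ~pos_mask

    # masked means of the log-probabilities for each group (guarding empty groups)
    pos = (log_probs * pos_mask).sum() / pos_mask.sum().clamp(min = 1)
    neg = -(log_probs * neg_mask).sum() / neg_mask.sum().clamp(min = 1)

    # negate the weighted combination so gradient descent pushes likelihood up on positive
    # advantages and down on negative ones
    return -(alpha * pos + (1. - alpha) * neg)

log_probs = torch.log(torch.rand(2, 16)) # stand-in log-probs of the sampled actions
advantage = torch.randn(2, 16)
loss = pmpo_policy_loss(log_probs, advantage, alpha = 0.5)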
@@ -2552,6 +2563,25 @@ class DynamicsWorldModel(Module):

    policy_loss = -(α * pos + (1. - α) * neg)

    # take care of kl

    if self.pmpo_kl_div_loss_weight > 0.:
        new_unembedded_actions = self.action_embedder.unembed(policy_embed, pred_head_index = 0)

        discrete_kl_div, continuous_kl_div = self.action_embedder.kl_div(new_unembedded_actions, old_action_unembeds)

        # accumulate discrete and continuous kl div

        kl_div_loss = 0.

        if exists(discrete_kl_div):
            kl_div_loss = kl_div_loss + discrete_kl_div[mask].mean()

        if exists(continuous_kl_div):
            kl_div_loss = kl_div_loss + continuous_kl_div[mask].mean()

        policy_loss = policy_loss + kl_div_loss * self.pmpo_kl_div_loss_weight

else:
    # ppo clipped surrogate loss

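Tying it together, a sketch of how the accumulated KL presumably enters the loss: either component may be absent (None) depending on whether the action space has discrete and/or continuous parts, the surviving ones are averaged over a validity mask (here assumed to be derived from per-sequence lengths such as `lens`), and the sum is scaled by `pmpo_kl_div_loss_weight`.

import torch

def exists(v):
    return v is not None

# toy per-timestep KL terms; either could be None if that action modality is absent
discrete_kl_div = torch.rand(2, 16)
continuous_kl_div = None

# validity mask, assumed to come from per-sequence lengths
lens = torch.tensor([16, 10])
mask = torch.arange(16)[None, :] < lens[:, None] # (2, 16) bool

kl_div_loss = 0.

if exists(discrete_kl_div):
    kl_div_loss = kl_div_loss + discrete_kl_div[mask].mean()

if exists(continuous_kl_div):
    kl_div_loss = kl_div_loss + continuous_kl_div[mask].mean()

pmpo_kl_div_loss_weight = 1.
policy_loss = torch.tensor(0.) # stand-in for the PMPO term above
policy_loss = policy_loss + kl_div_loss * pmpo_kl_div_loss_weight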
@@ -2694,6 +2724,10 @@ class DynamicsWorldModel(Module):

acc_agent_embed = None

# maybe store old actions for kl

acc_policy_embed = None

# maybe return rewards

decoded_rewards = None
@@ -2818,6 +2852,13 @@ class DynamicsWorldModel(Module):

policy_embed = self.policy_head(one_agent_embed)

# maybe store old actions

if store_old_action_unembeds:
    acc_policy_embed = safe_cat((acc_policy_embed, policy_embed))

# sample actions

sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed, pred_head_index = 0, squeeze = True)

decoded_discrete_actions = safe_cat((decoded_discrete_actions, sampled_discrete_actions), dim = 1)
@@ -2906,6 +2947,7 @@ class DynamicsWorldModel(Module):
video = video,
proprio = proprio if has_proprio else None,
agent_embed = acc_agent_embed if store_agent_embed else None,
old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if store_old_action_unembeds else None,
step_size = step_size,
agent_index = agent_index,
lens = experience_lens,
pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.92"
version = "0.0.93"
description = "Dreamer 4"
authors = [
    { name = "Phil Wang", email = "lucidrains@gmail.com" }