store the old action logits and mean / log variance, and calculate the kl divergence for pmpo from them during learn from experience

commit 3547344312 (parent 691d9ca007)
@@ -81,6 +81,7 @@ class Experience:
     rewards: Tensor | None = None
     actions: tuple[Tensor, Tensor] | None = None
     log_probs: tuple[Tensor, Tensor] | None = None
+    old_action_unembeds: tuple[Tensor, Tensor] | None = None
     values: Tensor | None = None
     step_size: int | None = None
     lens: Tensor | None = None
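To orient the reader: per the commit message, the new `old_action_unembeds` field carries the rollout-time action distribution parameters (logits for the discrete head, mean / log variance for the continuous head) from experience generation to the learning step. Below is a pared-down sketch of the container keeping only the fields touched here; the real dataclass has more fields, and the comments on what each tuple holds are an interpretation of the commit message rather than code from the repository.

```python
from dataclasses import dataclass
from torch import Tensor

@dataclass
class Experience:
    rewards: Tensor | None = None
    actions: tuple[Tensor, Tensor] | None = None                # (discrete actions, continuous actions)
    log_probs: tuple[Tensor, Tensor] | None = None              # log probs at sampling time
    old_action_unembeds: tuple[Tensor, Tensor] | None = None    # new: (discrete logits, continuous mean / log var) - interpretation
    values: Tensor | None = None
    step_size: int | None = None
    lens: Tensor | None = None
```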
@@ -1887,6 +1888,7 @@ class DynamicsWorldModel(Module):
         gae_lambda = 0.95,
         ppo_eps_clip = 0.2,
         pmpo_pos_to_neg_weight = 0.5, # pos and neg equal weight
+        pmpo_kl_div_loss_weight = 1.,
         value_clip = 0.4,
         policy_entropy_weight = .01,
         gae_use_accelerated = False
@@ -2093,6 +2095,7 @@ class DynamicsWorldModel(Module):
         # pmpo related

         self.pmpo_pos_to_neg_weight = pmpo_pos_to_neg_weight
+        self.pmpo_kl_div_loss_weight = pmpo_kl_div_loss_weight

         # rewards related

@@ -2222,7 +2225,8 @@ class DynamicsWorldModel(Module):
         max_timesteps = 16,
         env_is_vectorized = False,
         use_time_kv_cache = True,
-        store_agent_embed = False
+        store_agent_embed = False,
+        store_old_action_unembeds = False,
     ):
         assert exists(self.video_tokenizer)

@@ -2248,6 +2252,7 @@ class DynamicsWorldModel(Module):
         latents = None

         acc_agent_embed = None
+        acc_policy_embed = None

         # keep track of termination, for setting the `is_truncated` field on Experience and for early stopping interaction with env

@@ -2300,6 +2305,9 @@ class DynamicsWorldModel(Module):

         policy_embed = self.policy_head(one_agent_embed)

+        if store_old_action_unembeds:
+            acc_policy_embed = safe_cat((acc_policy_embed, policy_embed), dim = 1)
+
         # sample actions

         sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed, pred_head_index = 0, squeeze = True)
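The rollout loop above accumulates per-step policy embeddings with `safe_cat`, which presumably concatenates while tolerating a `None` accumulator on the first step. A minimal sketch under that assumption (the default `dim` and exact signature are guesses, not the repository's definition):

```python
import torch

def safe_cat(tensors, dim = 1):
    # drop None entries so the accumulator can simply start out as None
    tensors = [t for t in tensors if t is not None]

    if len(tensors) == 0:
        return None

    if len(tensors) == 1:
        return tensors[0]

    return torch.cat(tensors, dim = dim)

# usage mirroring the diff: the accumulator starts as None and grows along the time dimension
acc_policy_embed = None

for _ in range(3):
    policy_embed = torch.randn(2, 1, 512)  # (batch, time = 1, dim) - shape assumed
    acc_policy_embed = safe_cat((acc_policy_embed, policy_embed), dim = 1)

assert acc_policy_embed.shape == (2, 3, 512)
```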
@@ -2383,6 +2391,7 @@ class DynamicsWorldModel(Module):
             actions = (discrete_actions, continuous_actions),
             log_probs = (discrete_log_probs, continuous_log_probs),
             values = values,
+            old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if store_old_action_unembeds else None,
             agent_embed = acc_agent_embed if store_agent_embed else None,
             step_size = step_size,
             agent_index = agent_index,
@@ -2411,6 +2420,7 @@ class DynamicsWorldModel(Module):
         old_values = experience.values
         rewards = experience.rewards
         agent_embeds = experience.agent_embed
+        old_action_unembeds = experience.old_action_unembeds

         step_size = experience.step_size
         agent_index = experience.agent_index
@@ -2489,6 +2499,7 @@ class DynamicsWorldModel(Module):
         if use_pmpo:
             pos_advantage_mask = advantage >= 0.
             neg_advantage_mask = ~pos_advantage_mask
+
         else:
             advantage = F.layer_norm(advantage, advantage.shape, eps = eps)

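For context on what these masks feed into: PMPO uses only the sign of the advantage, pushing up the log-probability of actions taken where the advantage was non-negative and pushing it down where it was negative, with the two sides mixed by `pmpo_pos_to_neg_weight` (the `α` in the next hunk). A rough, self-contained sketch of that objective; the exact reduction used in the repository may differ:

```python
import torch

def pmpo_policy_loss(log_probs, advantage, alpha = 0.5):
    # log_probs, advantage: (batch, time) - shapes assumed for illustration

    pos_advantage_mask = advantage >= 0.
    neg_advantage_mask = ~pos_advantage_mask

    # only the sign of the advantage is used, not its magnitude
    pos = (log_probs * pos_advantage_mask).sum() / pos_advantage_mask.sum().clamp(min = 1)
    neg = (-log_probs * neg_advantage_mask).sum() / neg_advantage_mask.sum().clamp(min = 1)

    # maximize log prob on the positive set, minimize it on the negative set
    return -(alpha * pos + (1. - alpha) * neg)

log_probs = torch.rand(2, 8).log()   # stand-in for per-timestep action log probs
advantage = torch.randn(2, 8)

loss = pmpo_policy_loss(log_probs, advantage, alpha = 0.5)
```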
@@ -2552,6 +2563,25 @@ class DynamicsWorldModel(Module):

             policy_loss = -(α * pos + (1. - α) * neg)

+            # take care of kl
+
+            if self.pmpo_kl_div_loss_weight > 0.:
+                new_unembedded_actions = self.action_embedder.unembed(policy_embed, pred_head_index = 0)
+
+                discrete_kl_div, continuous_kl_div = self.action_embedder.kl_div(new_unembedded_actions, old_action_unembeds)
+
+                # accumulate discrete and continuous kl div
+
+                kl_div_loss = 0.
+
+                if exists(discrete_kl_div):
+                    kl_div_loss = kl_div_loss + discrete_kl_div[mask].mean()
+
+                if exists(continuous_kl_div):
+                    kl_div_loss = kl_div_loss + continuous_kl_div[mask].mean()
+
+                policy_loss = policy_loss + kl_div_loss * self.pmpo_kl_div_loss_weight
+
         else:
             # ppo clipped surrogate loss

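The commit message says the stored parameters are the old action logits and mean / log variance, so `action_embedder.kl_div` presumably evaluates a categorical KL term for the discrete head and a diagonal-Gaussian KL term for the continuous head between rollout-time and current-policy parameters. A hedged sketch of those two closed-form terms, with tensor layouts assumed for illustration:

```python
import torch

def categorical_kl(old_logits, new_logits):
    # KL(old || new) per position, summed over the discrete action dimension
    old_log_probs = old_logits.log_softmax(dim = -1)
    new_log_probs = new_logits.log_softmax(dim = -1)
    return (old_log_probs.exp() * (old_log_probs - new_log_probs)).sum(dim = -1)

def gaussian_kl(old_mean, old_log_var, new_mean, new_log_var):
    # KL(old || new) for diagonal Gaussians, summed over the continuous action dimension
    kl = 0.5 * (
        new_log_var - old_log_var
        + (old_log_var.exp() + (old_mean - new_mean) ** 2) / new_log_var.exp()
        - 1.
    )
    return kl.sum(dim = -1)

# toy shapes: (batch, time, num discrete actions) and (batch, time, continuous action dim)
old_logits, new_logits = torch.randn(2, 8, 6), torch.randn(2, 8, 6)
old_mean, old_log_var = torch.randn(2, 8, 3), torch.randn(2, 8, 3)
new_mean, new_log_var = torch.randn(2, 8, 3), torch.randn(2, 8, 3)

discrete_kl_div = categorical_kl(old_logits, new_logits)                        # (2, 8)
continuous_kl_div = gaussian_kl(old_mean, old_log_var, new_mean, new_log_var)   # (2, 8)

# mirroring the diff: mask out padded timesteps, average, then fold into the policy loss
mask = torch.ones(2, 8, dtype = torch.bool)
kl_div_loss = discrete_kl_div[mask].mean() + continuous_kl_div[mask].mean()
```

Penalizing this KL keeps the updated policy close to the distribution that generated the rollout, playing roughly the role that the clipped ratio plays in the PPO branch below.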
@@ -2694,6 +2724,10 @@ class DynamicsWorldModel(Module):

         acc_agent_embed = None

+        # maybe store old actions for kl
+
+        acc_policy_embed = None
+
         # maybe return rewards

         decoded_rewards = None
@@ -2818,6 +2852,13 @@ class DynamicsWorldModel(Module):

         policy_embed = self.policy_head(one_agent_embed)

+        # maybe store old actions
+
+        if store_old_action_unembeds:
+            acc_policy_embed = safe_cat((acc_policy_embed, policy_embed))
+
+        # sample actions
+
         sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed, pred_head_index = 0, squeeze = True)

         decoded_discrete_actions = safe_cat((decoded_discrete_actions, sampled_discrete_actions), dim = 1)
@@ -2906,6 +2947,7 @@ class DynamicsWorldModel(Module):
             video = video,
             proprio = proprio if has_proprio else None,
             agent_embed = acc_agent_embed if store_agent_embed else None,
+            old_action_unembeds = self.action_embedder.unembed(acc_policy_embed, pred_head_index = 0) if store_old_action_unembeds else None,
             step_size = step_size,
             agent_index = agent_index,
             lens = experience_lens,
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.92"
+version = "0.0.93"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }