some more control over whether to normalize advantages

parent 0904e224ab
commit 3beae186da
@@ -1902,6 +1902,7 @@ class DynamicsWorldModel(Module):
         pmpo_pos_to_neg_weight = 0.5, # pos and neg equal weight
         pmpo_reverse_kl = True,
         pmpo_kl_div_loss_weight = .3,
+        normalize_advantages = None,
         value_clip = 0.4,
         policy_entropy_weight = .01,
         gae_use_accelerated = False
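The new `normalize_advantages` keyword is a tri-state flag: `None` means "decide based on `use_pmpo`", while `True` / `False` force the behavior either way. A minimal sketch of that resolution, assuming the usual `exists` / `default` helpers; their definitions here are an assumption, only the `default(normalize_advantages, not use_pmpo)` call further down actually appears in the commit:

def exists(v):
    # assumed helper: a value counts as "given" when it is not None
    return v is not None

def default(v, d):
    # assumed helper: use the explicitly passed value, otherwise fall back to the derived default
    return v if exists(v) else d

# None  -> follow `not use_pmpo` (normalize only when PMPO is off)
# True  -> always normalize advantages, even under PMPO
# False -> never normalize advantages
for passed, use_pmpo in [(None, True), (None, False), (True, True), (False, False)]:
    print(passed, use_pmpo, '->', default(passed, not use_pmpo))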
@@ -2425,6 +2426,7 @@ class DynamicsWorldModel(Module):
         value_optim: Optimizer | None = None,
         only_learn_policy_value_heads = True, # in the paper, they do not finetune the entire dynamics model, they just learn the heads
         use_pmpo = True,
+        normalize_advantages = None,
         eps = 1e-6
     ):
 
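The same keyword is also threaded through the RL learning method, so the PMPO-derived default can be overridden per call. A rough usage sketch; the method name `finetune_policy_value` and the `memories` argument are hypothetical placeholders, only the keyword arguments shown in the hunk above come from the diff:

# hypothetical call site - method and data names are illustrative, not from the repo
world_model.finetune_policy_value(
    memories,
    use_pmpo = True,
    only_learn_policy_value_heads = True,
    normalize_advantages = True   # force normalization even though PMPO is on
)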
@@ -2507,16 +2509,19 @@ class DynamicsWorldModel(Module):
         else:
             advantage = returns - old_values
 
-        # apparently they just use the sign of the advantage
+        # if using pmpo, do not normalize advantages, but can be overridden
+
+        normalize_advantages = default(normalize_advantages, not use_pmpo)
+
+        if normalize_advantages:
+            advantage = F.layer_norm(advantage, advantage.shape, eps = eps)
 
         # https://arxiv.org/abs/2410.04166v1
 
         if use_pmpo:
             pos_advantage_mask = advantage >= 0.
             neg_advantage_mask = ~pos_advantage_mask
 
-        else:
-            advantage = F.layer_norm(advantage, advantage.shape, eps = eps)
 
         # replay for the action logits and values
         # but only do so if fine tuning the entire world model for RL
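A self-contained sketch of the advantage handling this hunk ends up with, assuming a flat tensor of per-step advantages; the shapes and toy data are illustrative, but the flag resolution, the `F.layer_norm` normalization, and the PMPO sign masks mirror the lines above:

import torch
import torch.nn.functional as F

eps = 1e-6
use_pmpo = True
normalize_advantages = None   # None -> derive from use_pmpo

returns = torch.randn(8)
old_values = torch.randn(8)

advantage = returns - old_values

# same resolution as the commit: PMPO defaults to raw, unnormalized advantages
normalize_advantages = normalize_advantages if normalize_advantages is not None else (not use_pmpo)

if normalize_advantages:
    # layer_norm over the full shape is (A - A.mean()) / A.std(unbiased = False), i.e. zero mean, unit variance
    advantage = F.layer_norm(advantage, advantage.shape, eps = eps)

if use_pmpo:
    # PMPO (https://arxiv.org/abs/2410.04166v1) only uses the sign of the advantage,
    # splitting steps into positively and negatively advantaged sets
    pos_advantage_mask = advantage >= 0.
    neg_advantage_mask = ~pos_advantage_mask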
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.101"
+version = "0.0.102"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }