From 0904e224ab7f8c3b966423104e4b6e8b6afb06de Mon Sep 17 00:00:00 2001 From: lucidrains Date: Thu, 30 Oct 2025 08:22:50 -0700 Subject: [PATCH] make the reverse kl optional --- dreamer4/dreamer4.py | 13 +++++++++++-- pyproject.toml | 2 +- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 6130b7a..dbfa7c4 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -1900,6 +1900,7 @@ class DynamicsWorldModel(Module): gae_lambda = 0.95, ppo_eps_clip = 0.2, pmpo_pos_to_neg_weight = 0.5, # pos and neg equal weight + pmpo_reverse_kl = True, pmpo_kl_div_loss_weight = .3, value_clip = 0.4, policy_entropy_weight = .01, @@ -2108,6 +2109,7 @@ class DynamicsWorldModel(Module): self.pmpo_pos_to_neg_weight = pmpo_pos_to_neg_weight self.pmpo_kl_div_loss_weight = pmpo_kl_div_loss_weight + self.pmpo_reverse_kl = pmpo_reverse_kl # rewards related @@ -2578,11 +2580,18 @@ class DynamicsWorldModel(Module): # take care of kl if self.pmpo_kl_div_loss_weight > 0.: + new_unembedded_actions = self.action_embedder.unembed(policy_embed, pred_head_index = 0) - # mentioned that the "reverse direction for the prior KL" was used + kl_div_inputs, kl_div_targets = new_unembedded_actions, old_action_unembeds - discrete_kl_div, continuous_kl_div = self.action_embedder.kl_div(old_action_unembeds, new_unembedded_actions) + # mentioned that the "reverse direction for the prior KL" was used + # make optional, as observed instability in toy task + + if self.pmpo_reverse_kl: + kl_div_inputs, kl_div_targets = kl_div_targets, kl_div_inputs + + discrete_kl_div, continuous_kl_div = self.action_embedder.kl_div(kl_div_inputs, kl_div_targets) # accumulate discrete and continuous kl div diff --git a/pyproject.toml b/pyproject.toml index 8693390..fe1d80f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.0.100" +version = "0.0.101" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }