From cb416c0d44cf2679c2a301819e546788120fe408 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Fri, 17 Oct 2025 08:47:26 -0700
Subject: [PATCH] handle the entropies during policy optimization

---
 dreamer4/dreamer4.py  | 57 +++++++++++++++++++++++++++++++++++++------
 pyproject.toml        |  2 +-
 tests/test_dreamer.py | 11 +++++++++
 3 files changed, 62 insertions(+), 8 deletions(-)

diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 56ec538..f2eaec2 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -506,7 +506,8 @@ class ActionEmbedder(Module):
         continuous_targets = None, # (... na)
         discrete_action_types = None, # (na)
         continuous_action_types = None, # (na)
-        parallel_discrete_calc = None
+        parallel_discrete_calc = None,
+        return_entropies = False
     ):

         parallel_discrete_calc = default(parallel_discrete_calc, exists(discrete_targets) and discrete_targets.shape[-1] > 1)
@@ -515,6 +516,7 @@ class ActionEmbedder(Module):
         # discrete

         discrete_log_probs = None
+        discrete_entropies = None

         if exists(discrete_targets):
@@ -534,7 +536,18 @@ class ActionEmbedder(Module):

                 nested_logits = nested_tensor(discrete_action_logits.split(jagged_dims * batch), layout = torch.jagged, device = self.device, requires_grad = True)

-                log_probs = log(nested_logits.softmax(dim = -1))
+                prob = nested_logits.softmax(dim = -1)
+
+                log_probs = log(prob)
+
+                # maybe entropy
+
+                if return_entropies:
+                    discrete_entropies = (-prob * log_probs).sum(dim = -1, keepdim = True)
+                    discrete_entropies = cat(discrete_entropies.unbind())
+                    discrete_entropies = rearrange(discrete_entropies, '(b na) -> b na', b = batch)
+
+                    discrete_entropies = inverse_pack_lead_dims(discrete_entropies, '* na')

                 # back to regular tensor
@@ -558,20 +571,30 @@ class ActionEmbedder(Module):
             else:
                 discrete_log_probs = []
+                discrete_entropies = []

                 for one_discrete_action_logit, one_discrete_target in zip(discrete_action_logits, discrete_targets.unbind(dim = -1)):
-                    one_discrete_log_probs = one_discrete_action_logit.log_softmax(dim = -1)
+                    one_discrete_probs = one_discrete_action_logit.softmax(dim = -1)
+                    one_discrete_log_probs = log(one_discrete_probs)

                     one_discrete_target = rearrange(one_discrete_target, '... -> ... 1')

                     log_prob = one_discrete_log_probs.gather(-1, one_discrete_target)
                     discrete_log_probs.append(log_prob)

+                    if return_entropies:
+                        entropy = (-one_discrete_probs * one_discrete_log_probs).sum(dim = -1)
+                        discrete_entropies.append(entropy)
+
                 discrete_log_probs = cat(discrete_log_probs, dim = -1)

+                if return_entropies:
+                    discrete_entropies = stack(discrete_entropies, dim = -1)
+
         # continuous

         continuous_log_probs = None
+        continuous_entropies = None

         if exists(continuous_targets):
             mean, log_var = continuous_action_mean_log_var.unbind(dim = -1)
@@ -580,7 +603,17 @@ class ActionEmbedder(Module):
             distr = Normal(mean, std)
             continuous_log_probs = distr.log_prob(continuous_targets)

-        return discrete_log_probs, continuous_log_probs
+            if return_entropies:
+                continuous_entropies = distr.entropy()
+
+        log_probs = (discrete_log_probs, continuous_log_probs)
+
+        if not return_entropies:
+            return log_probs
+
+        entropies = (discrete_entropies, continuous_entropies)
+
+        return log_probs, entropies

     def forward(
         self,
@@ -1708,12 +1741,13 @@ class DynamicsWorldModel(Module):

         policy_embed = self.policy_head(agent_embed)

-        log_probs = self.action_embedder.log_probs(policy_embed, discrete_targets = discrete_actions, continuous_targets = continuous_actions)
+        log_probs, entropies = self.action_embedder.log_probs(policy_embed, discrete_targets = discrete_actions, continuous_targets = continuous_actions, return_entropies = True)

         # concat discrete and continuous actions into one for optimizing

         old_log_probs = safe_cat(old_log_probs, dim = -1)
         log_probs = safe_cat(log_probs, dim = -1)
+        entropies = safe_cat(entropies, dim = -1)

         ratio = (log_probs - old_log_probs).exp()
@@ -1724,10 +1758,19 @@ class DynamicsWorldModel(Module):

         # clipped surrogate loss

         policy_loss = -torch.min(ratio * advantage, clipped_ratio * advantage)
-        policy_loss = reduce(policy_loss, 'b t na -> b t', 'sum')

         policy_loss = policy_loss.mean()

+        # handle entropy loss for naive exploration bonus
+
+        entropy_loss = - reduce(entropies, 'b t na -> b t', 'sum').mean()
+
+        total_policy_loss = (
+            policy_loss +
+            entropy_loss * self.policy_entropy_weight
+        )
+
         # value loss

         value_bins = self.value_head(agent_embed)
@@ -1743,7 +1786,7 @@ class DynamicsWorldModel(Module):

         value_loss = torch.maximum(value_loss_1, value_loss_2).mean()

-        return policy_loss, value_loss
+        return total_policy_loss, value_loss

     @torch.no_grad()
     def generate(
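For orientation, the sketch below restates what the DynamicsWorldModel hunks above assemble: a PPO-style clipped surrogate plus a naive entropy bonus summed over the action dimension. It assumes plain (b, t, na) shaped tensors and a diagonal Gaussian continuous head; the function name, argument list, and default weights are illustrative stand-ins, not the dreamer4 API.

# minimal sketch of the entropy bonus wired in above, assuming plain tensors
# of shape (b, t, na) and (b, t, na, bins); names here are illustrative,
# not the dreamer4 API

import torch
from torch.distributions import Normal

def entropy_regularized_policy_loss(
    log_probs,        # (b, t, na) log probs of the taken actions
    old_log_probs,    # (b, t, na)
    advantage,        # (b, t, 1) or broadcastable to (b, t, na)
    discrete_logits,  # (b, t, na, bins)
    mean,             # (b, t, na) continuous action head
    std,              # (b, t, na)
    clip_eps = 0.2,
    entropy_weight = 0.01  # analogous to self.policy_entropy_weight
):
    # ppo clipped surrogate, as in the patched hunk
    ratio = (log_probs - old_log_probs).exp()
    clipped_ratio = ratio.clamp(1. - clip_eps, 1. + clip_eps)
    policy_loss = -torch.min(ratio * advantage, clipped_ratio * advantage).mean()

    # discrete entropy: -sum p * log p over the bins -> (b, t, na)
    log_p = discrete_logits.log_softmax(dim = -1)
    discrete_entropy = (-log_p.exp() * log_p).sum(dim = -1)

    # continuous entropy from the Normal head -> (b, t, na)
    continuous_entropy = Normal(mean, std).entropy()

    # naive exploration bonus: sum over actions, mean over batch and time
    entropies = torch.cat((discrete_entropy, continuous_entropy), dim = -1)
    entropy_loss = -entropies.sum(dim = -1).mean()

    return policy_loss + entropy_weight * entropy_loss

In the patch itself the per-action entropies come out of ActionEmbedder.log_probs(..., return_entropies = True), are concatenated with safe_cat, and are weighted by self.policy_entropy_weight; for the continuous head, distr.entropy() is the closed-form Gaussian entropy 0.5 * log(2 * pi * e * sigma**2) per action dimension.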
diff --git a/pyproject.toml b/pyproject.toml
index 1aed066..f369bac 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.29"
+version = "0.0.30"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index 2b6aba3..ac983b2 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -366,6 +366,17 @@ def test_action_embedder():
     assert discrete_log_probs.shape == (2, 3, 2)
     assert continuous_log_probs.shape == (2, 3, 2)

+    _, (discrete_entropies, continuous_entropies) = embedder.log_probs(
+        action_embed,
+        discrete_targets = discrete_actions,
+        continuous_targets = continuous_actions,
+        parallel_discrete_calc = True,
+        return_entropies = True
+    )
+
+    assert discrete_entropies.shape == (2, 3, 2)
+    assert continuous_entropies.shape == (2, 3, 2)
+
     parallel_discrete_log_probs, _ = embedder.log_probs(
         action_embed,
         discrete_targets = discrete_actions,