handle the entropies during policy optimization

This commit is contained in:
lucidrains 2025-10-17 08:47:26 -07:00
parent 61773c8219
commit cb416c0d44
3 changed files with 62 additions and 8 deletions

View File

@ -506,7 +506,8 @@ class ActionEmbedder(Module):
continuous_targets = None, # (... na)
discrete_action_types = None, # (na)
continuous_action_types = None, # (na)
parallel_discrete_calc = None
parallel_discrete_calc = None,
return_entropies = False
):
parallel_discrete_calc = default(parallel_discrete_calc, exists(discrete_targets) and discrete_targets.shape[-1] > 1)
@ -515,6 +516,7 @@ class ActionEmbedder(Module):
# discrete
discrete_log_probs = None
discrete_entropies = None
if exists(discrete_targets):
@ -534,7 +536,18 @@ class ActionEmbedder(Module):
nested_logits = nested_tensor(discrete_action_logits.split(jagged_dims * batch), layout = torch.jagged, device = self.device, requires_grad = True)
log_probs = log(nested_logits.softmax(dim = -1))
prob = nested_logits.softmax(dim = -1)
log_probs = log(prob)
# maybe entropy
if return_entropies:
discrete_entropies = (-prob * log_probs).sum(dim = -1, keepdim = True)
discrete_entropies = cat(discrete_entropies.unbind())
discrete_entropies = rearrange(discrete_entropies, '(b na) -> b na', b = batch)
discrete_entropies = inverse_pack_lead_dims(discrete_entropies, '* na')
# back to regular tensor
@ -558,20 +571,30 @@ class ActionEmbedder(Module):
else:
discrete_log_probs = []
discrete_entropies = []
for one_discrete_action_logit, one_discrete_target in zip(discrete_action_logits, discrete_targets.unbind(dim = -1)):
one_discrete_log_probs = one_discrete_action_logit.log_softmax(dim = -1)
one_discrete_probs = one_discrete_action_logit.softmax(dim = -1)
one_discrete_log_probs = log(one_discrete_probs)
one_discrete_target = rearrange(one_discrete_target, '... -> ... 1')
log_prob = one_discrete_log_probs.gather(-1, one_discrete_target)
discrete_log_probs.append(log_prob)
if return_entropies:
entropy = (-one_discrete_probs * one_discrete_log_probs).sum(dim = -1)
discrete_entropies.append(entropy)
discrete_log_probs = cat(discrete_log_probs, dim = -1)
if return_entropies:
discrete_entropies = stack(discrete_entropies, dim = -1)
# continuous
continuous_log_probs = None
continuous_entropies = None
if exists(continuous_targets):
mean, log_var = continuous_action_mean_log_var.unbind(dim = -1)
@ -580,7 +603,17 @@ class ActionEmbedder(Module):
distr = Normal(mean, std)
continuous_log_probs = distr.log_prob(continuous_targets)
return discrete_log_probs, continuous_log_probs
if return_entropies:
continuous_entropies = distr.entropy()
log_probs = (discrete_log_probs, continuous_log_probs)
if not return_entropies:
return log_probs
entropies = (discrete_entropies, continuous_entropies)
return log_probs, entropies
def forward(
self,
@ -1708,12 +1741,13 @@ class DynamicsWorldModel(Module):
policy_embed = self.policy_head(agent_embed)
log_probs = self.action_embedder.log_probs(policy_embed, discrete_targets = discrete_actions, continuous_targets = continuous_actions)
log_probs, entropies = self.action_embedder.log_probs(policy_embed, discrete_targets = discrete_actions, continuous_targets = continuous_actions, return_entropies = True)
# concat discrete and continuous actions into one for optimizing
old_log_probs = safe_cat(old_log_probs, dim = -1)
log_probs = safe_cat(log_probs, dim = -1)
entropies = safe_cat(entropies, dim = -1)
ratio = (log_probs - old_log_probs).exp()
@ -1724,10 +1758,19 @@ class DynamicsWorldModel(Module):
# clipped surrogate loss
policy_loss = -torch.min(ratio * advantage, clipped_ratio * advantage)
policy_loss = reduce(policy_loss, 'b t na -> b t', 'sum')
policy_loss = policy_loss.mean()
# handle entropy loss for naive exploration bonus
entropy_loss = - reduce(entropies, 'b t na -> b t', 'sum').mean()
total_policy_loss = (
policy_loss +
entropy_loss * self.policy_entropy_weight
)
# value loss
value_bins = self.value_head(agent_embed)
@ -1743,7 +1786,7 @@ class DynamicsWorldModel(Module):
value_loss = torch.maximum(value_loss_1, value_loss_2).mean()
return policy_loss, value_loss
return total_policy_loss, value_loss
@torch.no_grad()
def generate(

View File

@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.29"
version = "0.0.30"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

View File

@ -366,6 +366,17 @@ def test_action_embedder():
assert discrete_log_probs.shape == (2, 3, 2)
assert continuous_log_probs.shape == (2, 3, 2)
_, (discrete_entropies, continuous_entropies) = embedder.log_probs(
action_embed,
discrete_targets = discrete_actions,
continuous_targets = continuous_actions,
parallel_discrete_calc = True,
return_entropies = True
)
assert discrete_entropies.shape == (2, 3, 2)
assert continuous_entropies.shape == (2, 3, 2)
parallel_discrete_log_probs, _ = embedder.log_probs(
action_embed,
discrete_targets = discrete_actions,