handle the entropies during policy optimization
parent 61773c8219
commit cb416c0d44
@@ -506,7 +506,8 @@ class ActionEmbedder(Module):
         continuous_targets = None, # (... na)
         discrete_action_types = None, # (na)
         continuous_action_types = None, # (na)
-        parallel_discrete_calc = None
+        parallel_discrete_calc = None,
+        return_entropies = False
     ):
         parallel_discrete_calc = default(parallel_discrete_calc, exists(discrete_targets) and discrete_targets.shape[-1] > 1)
 
@@ -515,6 +516,7 @@ class ActionEmbedder(Module):
 
         # discrete
 
         discrete_log_probs = None
+        discrete_entropies = None
 
         if exists(discrete_targets):
@@ -534,7 +536,18 @@ class ActionEmbedder(Module):
 
                nested_logits = nested_tensor(discrete_action_logits.split(jagged_dims * batch), layout = torch.jagged, device = self.device, requires_grad = True)
 
-               log_probs = log(nested_logits.softmax(dim = -1))
+               prob = nested_logits.softmax(dim = -1)
+
+               log_probs = log(prob)
+
+               # maybe entropy
+
+               if return_entropies:
+                   discrete_entropies = (-prob * log_probs).sum(dim = -1, keepdim = True)
+                   discrete_entropies = cat(discrete_entropies.unbind())
+                   discrete_entropies = rearrange(discrete_entropies, '(b na) -> b na', b = batch)
+
+                   discrete_entropies = inverse_pack_lead_dims(discrete_entropies, '* na')
+
                # back to regular tensor
 
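Aside from the nested-tensor plumbing above (nested_tensor, inverse_pack_lead_dims and log are helpers from this repo), the quantity being added is just the categorical entropy -sum(p * log p) over each action's logits. A minimal standalone sketch of that identity in plain torch, with illustrative shapes and a clamped log standing in for the repo's log helper (an assumption):

import torch
from torch.distributions import Categorical

logits = torch.randn(2, 3, 5)                 # (batch, num discrete actions, num bins) - illustrative

prob = logits.softmax(dim = -1)
log_probs = prob.clamp(min = 1e-20).log()     # numerically safe log, assumed to mirror log()

entropies = (-prob * log_probs).sum(dim = -1) # -sum p log p, the same reduction as in the diff

assert torch.allclose(entropies, Categorical(logits = logits).entropy(), atol = 1e-6)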
@@ -558,20 +571,30 @@ class ActionEmbedder(Module):
 
            else:
                discrete_log_probs = []
+                discrete_entropies = []
 
                for one_discrete_action_logit, one_discrete_target in zip(discrete_action_logits, discrete_targets.unbind(dim = -1)):
 
-                    one_discrete_log_probs = one_discrete_action_logit.log_softmax(dim = -1)
+                    one_discrete_probs = one_discrete_action_logit.softmax(dim = -1)
+                    one_discrete_log_probs = log(one_discrete_probs)
                     one_discrete_target = rearrange(one_discrete_target, '... -> ... 1')
 
                     log_prob = one_discrete_log_probs.gather(-1, one_discrete_target)
                     discrete_log_probs.append(log_prob)
 
+                    if return_entropies:
+                        entropy = (-one_discrete_probs * one_discrete_log_probs).sum(dim = -1)
+                        discrete_entropies.append(entropy)
+
                discrete_log_probs = cat(discrete_log_probs, dim = -1)
 
+                if return_entropies:
+                    discrete_entropies = stack(discrete_entropies, dim = -1)
+
        # continuous
 
        continuous_log_probs = None
+        continuous_entropies = None
 
        if exists(continuous_targets):
            mean, log_var = continuous_action_mean_log_var.unbind(dim = -1)
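The non-parallel branch does the same thing one action type at a time: gather picks out the log prob of the taken bin, and the per-type results are concatenated (log probs) or stacked (entropies) along a trailing action dimension. A short sketch of that loop with made-up shapes, independent of the repo's helpers:

import torch
from torch import cat, stack

# two discrete action types with different bin counts - illustrative only
logits_per_type = [torch.randn(2, 3, 7), torch.randn(2, 3, 5)]
targets = torch.randint(0, 5, (2, 3, 2))          # (batch, time, num action types)

log_probs, entropies = [], []

for one_logits, one_target in zip(logits_per_type, targets.unbind(dim = -1)):
    one_probs = one_logits.softmax(dim = -1)
    one_log_probs = one_probs.clamp(min = 1e-20).log()

    log_probs.append(one_log_probs.gather(-1, one_target.unsqueeze(-1)))  # (2, 3, 1)
    entropies.append((-one_probs * one_log_probs).sum(dim = -1))          # (2, 3)

log_probs = cat(log_probs, dim = -1)    # (2, 3, num action types)
entropies = stack(entropies, dim = -1)  # (2, 3, num action types)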
@@ -580,7 +603,17 @@ class ActionEmbedder(Module):
            distr = Normal(mean, std)
            continuous_log_probs = distr.log_prob(continuous_targets)
 
-        return discrete_log_probs, continuous_log_probs
+            if return_entropies:
+                continuous_entropies = distr.entropy()
+
+        log_probs = (discrete_log_probs, continuous_log_probs)
+
+        if not return_entropies:
+            return log_probs
+
+        entropies = (discrete_entropies, continuous_entropies)
+
+        return log_probs, entropies
 
    def forward(
        self,
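For continuous actions the entropy comes directly from torch.distributions.Normal, which for a Gaussian is 0.5 * log(2 * pi * e * sigma^2) per dimension; the method then returns the (discrete, continuous) log-prob tuple, plus a matching entropies tuple only when return_entropies is set. A quick check of the Gaussian entropy identity (values are arbitrary):

import math
import torch
from torch.distributions import Normal

mean = torch.zeros(2, 3)
std = torch.full((2, 3), 0.5)

distr = Normal(mean, std)

entropy = distr.entropy()                                    # what the diff returns for continuous actions
closed_form = 0.5 * (2 * math.pi * math.e * std ** 2).log()  # Gaussian differential entropy

assert torch.allclose(entropy, closed_form, atol = 1e-6)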
@@ -1708,12 +1741,13 @@ class DynamicsWorldModel(Module):
 
        policy_embed = self.policy_head(agent_embed)
 
-        log_probs = self.action_embedder.log_probs(policy_embed, discrete_targets = discrete_actions, continuous_targets = continuous_actions)
+        log_probs, entropies = self.action_embedder.log_probs(policy_embed, discrete_targets = discrete_actions, continuous_targets = continuous_actions, return_entropies = True)
 
        # concat discrete and continuous actions into one for optimizing
 
        old_log_probs = safe_cat(old_log_probs, dim = -1)
        log_probs = safe_cat(log_probs, dim = -1)
+        entropies = safe_cat(entropies, dim = -1)
 
        ratio = (log_probs - old_log_probs).exp()
 
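The ratio computed above is the usual PPO importance weight exp(log pi_new(a) - log pi_old(a)), taken per concatenated action dimension. A tiny numeric illustration with made-up log probs:

import torch

old_log_probs = torch.tensor([[-1.2, -0.7]])
log_probs     = torch.tensor([[-1.0, -0.9]])

ratio = (log_probs - old_log_probs).exp()  # > 1 where the new policy favors the taken action more
print(ratio)                               # tensor([[1.2214, 0.8187]])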
@@ -1724,10 +1758,19 @@ class DynamicsWorldModel(Module):
        # clipped surrogate loss
 
        policy_loss = -torch.min(ratio * advantage, clipped_ratio * advantage)
 
        policy_loss = reduce(policy_loss, 'b t na -> b t', 'sum')
 
        policy_loss = policy_loss.mean()
 
+        # handle entropy loss for naive exploration bonus
+
+        entropy_loss = - reduce(entropies, 'b t na -> b t', 'sum').mean()
+
+        total_policy_loss = (
+            policy_loss +
+            entropy_loss * self.policy_entropy_weight
+        )
+
        # value loss
 
        value_bins = self.value_head(agent_embed)
 
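The entropy term follows the usual sign convention: entropy_loss is the negative summed entropy, so adding it with a positive policy_entropy_weight rewards higher-entropy (more exploratory) policies. A self-contained sketch of how the pieces combine; the clip epsilon and weight values are assumptions, not the repo's defaults:

import torch
from einops import reduce

# stand-ins for the tensors in the diff, shaped (batch, time, num action dims)
ratio     = torch.rand(2, 3, 4) * 0.4 + 0.8
advantage = torch.randn(2, 3, 1)
entropies = torch.rand(2, 3, 4)

eps = 0.2                     # assumed PPO clip epsilon
policy_entropy_weight = 0.01  # assumed exploration bonus weight

clipped_ratio = ratio.clamp(1. - eps, 1. + eps)

policy_loss = -torch.min(ratio * advantage, clipped_ratio * advantage)
policy_loss = reduce(policy_loss, 'b t na -> b t', 'sum').mean()

entropy_loss = -reduce(entropies, 'b t na -> b t', 'sum').mean()

total_policy_loss = policy_loss + entropy_loss * policy_entropy_weight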
@@ -1743,7 +1786,7 @@ class DynamicsWorldModel(Module):
 
        value_loss = torch.maximum(value_loss_1, value_loss_2).mean()
 
-        return policy_loss, value_loss
+        return total_policy_loss, value_loss
 
    @torch.no_grad()
    def generate(
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.29"
+version = "0.0.30"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -366,6 +366,17 @@ def test_action_embedder():
    assert discrete_log_probs.shape == (2, 3, 2)
    assert continuous_log_probs.shape == (2, 3, 2)
 
+    _, (discrete_entropies, continuous_entropies) = embedder.log_probs(
+        action_embed,
+        discrete_targets = discrete_actions,
+        continuous_targets = continuous_actions,
+        parallel_discrete_calc = True,
+        return_entropies = True
+    )
+
+    assert discrete_entropies.shape == (2, 3, 2)
+    assert continuous_entropies.shape == (2, 3, 2)
+
    parallel_discrete_log_probs, _ = embedder.log_probs(
        action_embed,
        discrete_targets = discrete_actions,