handle the entropies during policy optimization

commit cb416c0d44 (parent 61773c8219)
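In brief: ActionEmbedder.log_probs gains a return_entropies flag and, when it is set, returns the per-action discrete and continuous entropies alongside the log probabilities. DynamicsWorldModel then folds these into policy optimization as a naive exploration bonus weighted by policy_entropy_weight, schematically

    total_policy_loss = policy_loss + policy_entropy_weight * entropy_loss,    where entropy_loss = -(summed per-action entropy, averaged over batch and time)

so minimizing the total loss lowers the clipped surrogate while raising policy entropy. A test covers the new return shapes and the package version is bumped to 0.0.30.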
@@ -506,7 +506,8 @@ class ActionEmbedder(Module):
         continuous_targets = None, # (... na)
         discrete_action_types = None, # (na)
         continuous_action_types = None, # (na)
-        parallel_discrete_calc = None
+        parallel_discrete_calc = None,
+        return_entropies = False
     ):
         parallel_discrete_calc = default(parallel_discrete_calc, exists(discrete_targets) and discrete_targets.shape[-1] > 1)

@@ -515,6 +516,7 @@ class ActionEmbedder(Module):
         # discrete

         discrete_log_probs = None
+        discrete_entropies = None

         if exists(discrete_targets):

@@ -534,7 +536,18 @@ class ActionEmbedder(Module):

                 nested_logits = nested_tensor(discrete_action_logits.split(jagged_dims * batch), layout = torch.jagged, device = self.device, requires_grad = True)

-                log_probs = log(nested_logits.softmax(dim = -1))
+                prob = nested_logits.softmax(dim = -1)
+
+                log_probs = log(prob)
+
+                # maybe entropy
+
+                if return_entropies:
+                    discrete_entropies = (-prob * log_probs).sum(dim = -1, keepdim = True)
+                    discrete_entropies = cat(discrete_entropies.unbind())
+                    discrete_entropies = rearrange(discrete_entropies, '(b na) -> b na', b = batch)
+
+                    discrete_entropies = inverse_pack_lead_dims(discrete_entropies, '* na')

                 # back to regular tensor

@@ -558,20 +571,30 @@ class ActionEmbedder(Module):

             else:
                 discrete_log_probs = []
+                discrete_entropies = []

                 for one_discrete_action_logit, one_discrete_target in zip(discrete_action_logits, discrete_targets.unbind(dim = -1)):

-                    one_discrete_log_probs = one_discrete_action_logit.log_softmax(dim = -1)
+                    one_discrete_probs = one_discrete_action_logit.softmax(dim = -1)
+                    one_discrete_log_probs = log(one_discrete_probs)
                     one_discrete_target = rearrange(one_discrete_target, '... -> ... 1')

                     log_prob = one_discrete_log_probs.gather(-1, one_discrete_target)
                     discrete_log_probs.append(log_prob)

+                    if return_entropies:
+                        entropy = (-one_discrete_probs * one_discrete_log_probs).sum(dim = -1)
+                        discrete_entropies.append(entropy)
+
                 discrete_log_probs = cat(discrete_log_probs, dim = -1)

+                if return_entropies:
+                    discrete_entropies = stack(discrete_entropies, dim = -1)
+
         # continuous

         continuous_log_probs = None
+        continuous_entropies = None

         if exists(continuous_targets):
             mean, log_var = continuous_action_mean_log_var.unbind(dim = -1)
@@ -580,7 +603,17 @@ class ActionEmbedder(Module):
             distr = Normal(mean, std)
             continuous_log_probs = distr.log_prob(continuous_targets)

-        return discrete_log_probs, continuous_log_probs
+            if return_entropies:
+                continuous_entropies = distr.entropy()
+
+        log_probs = (discrete_log_probs, continuous_log_probs)
+
+        if not return_entropies:
+            return log_probs
+
+        entropies = (discrete_entropies, continuous_entropies)
+
+        return log_probs, entropies

     def forward(
         self,
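Both discrete branches above (the jagged nested-tensor path and the per-action loop) accumulate the categorical entropy -sum(p * log p), while the continuous branch relies on the closed-form Gaussian entropy from Normal.entropy(). A minimal, standalone sketch of those two identities on plain tensors with made-up shapes (not the repository's code):

import torch
from torch.distributions import Categorical, Normal

# hypothetical logits for a single discrete action head: (batch, num_classes)
logits = torch.randn(2, 5)

probs = logits.softmax(dim = -1)
log_probs = probs.log()

# the quantity accumulated in the diff: -sum(p * log p) over the class dimension
manual_entropy = (-probs * log_probs).sum(dim = -1)

# identical to the library's categorical entropy
assert torch.allclose(manual_entropy, Categorical(logits = logits).entropy(), atol = 1e-6)

# continuous entropies come from the Gaussian closed form, 0.5 * log(2 * pi * e * sigma^2)
mean, std = torch.zeros(2, 3), 0.5 * torch.ones(2, 3)
gaussian_entropy = Normal(mean, std).entropy()
assert torch.allclose(gaussian_entropy, 0.5 * (2 * torch.pi * torch.e * std ** 2).log())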
@@ -1708,12 +1741,13 @@ class DynamicsWorldModel(Module):

         policy_embed = self.policy_head(agent_embed)

-        log_probs = self.action_embedder.log_probs(policy_embed, discrete_targets = discrete_actions, continuous_targets = continuous_actions)
+        log_probs, entropies = self.action_embedder.log_probs(policy_embed, discrete_targets = discrete_actions, continuous_targets = continuous_actions, return_entropies = True)

         # concat discrete and continuous actions into one for optimizing

         old_log_probs = safe_cat(old_log_probs, dim = -1)
         log_probs = safe_cat(log_probs, dim = -1)
+        entropies = safe_cat(entropies, dim = -1)

         ratio = (log_probs - old_log_probs).exp()

@@ -1724,10 +1758,19 @@ class DynamicsWorldModel(Module):
         # clipped surrogate loss

         policy_loss = -torch.min(ratio * advantage, clipped_ratio * advantage)

         policy_loss = reduce(policy_loss, 'b t na -> b t', 'sum')

         policy_loss = policy_loss.mean()

+        # handle entropy loss for naive exploration bonus
+
+        entropy_loss = - reduce(entropies, 'b t na -> b t', 'sum').mean()
+
+        total_policy_loss = (
+            policy_loss +
+            entropy_loss * self.policy_entropy_weight
+        )
+
         # value loss

         value_bins = self.value_head(agent_embed)
@@ -1743,7 +1786,7 @@ class DynamicsWorldModel(Module):

         value_loss = torch.maximum(value_loss_1, value_loss_2).mean()

-        return policy_loss, value_loss
+        return total_policy_loss, value_loss

     @torch.no_grad()
     def generate(
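Taken together, the DynamicsWorldModel hunks implement an entropy-regularized, PPO-style clipped surrogate objective. The following is a self-contained sketch of that objective with made-up tensors; the clip epsilon and the definition of clipped_ratio are assumptions (the diff only shows clipped_ratio being used), and the weight value is illustrative rather than the repository's:

import torch

# hypothetical shapes: batch = 2, time = 3, num_actions = 4
log_probs     = torch.randn(2, 3, 4)
old_log_probs = torch.randn(2, 3, 4)
advantage     = torch.randn(2, 3, 1)   # broadcast across actions for this sketch
entropies     = torch.rand(2, 3, 4)    # per-action entropies, as returned with return_entropies = True

policy_entropy_weight = 0.01           # assumed value for self.policy_entropy_weight
eps = 0.2                              # assumed PPO clip epsilon

ratio = (log_probs - old_log_probs).exp()
clipped_ratio = ratio.clamp(1. - eps, 1. + eps)   # assumed definition of clipped_ratio

# clipped surrogate: sum over actions, then mean over batch and time (as in the diff)
policy_loss = -torch.min(ratio * advantage, clipped_ratio * advantage)
policy_loss = policy_loss.sum(dim = -1).mean()

# naive exploration bonus: negative entropy, so minimizing the loss raises entropy
entropy_loss = -entropies.sum(dim = -1).mean()

total_policy_loss = policy_loss + entropy_loss * policy_entropy_weight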
pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.29"
+version = "0.0.30"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -366,6 +366,17 @@ def test_action_embedder():
     assert discrete_log_probs.shape == (2, 3, 2)
     assert continuous_log_probs.shape == (2, 3, 2)

+    _, (discrete_entropies, continuous_entropies) = embedder.log_probs(
+        action_embed,
+        discrete_targets = discrete_actions,
+        continuous_targets = continuous_actions,
+        parallel_discrete_calc = True,
+        return_entropies = True
+    )
+
+    assert discrete_entropies.shape == (2, 3, 2)
+    assert continuous_entropies.shape == (2, 3, 2)
+
     parallel_discrete_log_probs, _ = embedder.log_probs(
         action_embed,
         discrete_targets = discrete_actions,