handle variable-length experiences when doing policy optimization
parent 46432aee9b · commit fbfd59e42f
@@ -82,7 +82,7 @@ class Experience:
     log_probs: tuple[Tensor, Tensor] | None = None
     values: Tensor | None = None
     step_size: int | None = None
-    lens: Tensor | None = None,
+    lens: Tensor | None = None
     agent_index: int = 0
     is_from_world_model: bool = True
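A hypothetical illustration (not part of the commit) of how the `lens` field on `Experience` could be populated: rollouts of different lengths are padded to the longest one along the time dimension, and `lens` records each rollout's true length so the losses below can ignore the padding. Shapes and values here are made up.

```python
import torch

batch, max_time, latent_dim = 2, 5, 16

# rollout 0 genuinely uses 5 steps, rollout 1 only 3; the rest is padding
padded_latents = torch.zeros(batch, max_time, latent_dim)
lens = torch.tensor([5, 3])

# Experience(..., lens = lens) would then carry the true lengths alongside
# the padded tensors into policy optimization
```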
@@ -2226,6 +2226,15 @@ class DynamicsWorldModel(Module):
 
         returns = calc_gae(rewards, old_values, gamma = self.gae_discount_factor, lam = self.gae_lambda, use_accelerated = self.gae_use_accelerated)
 
+        # handle variable lengths
+
+        max_time = latents.shape[1]
+        is_var_len = exists(experience.lens)
+
+        if is_var_len:
+            lens = experience.lens
+            mask = lens_to_mask(lens, max_time)
+
         # determine whether to finetune entire transformer or just learn the heads
 
         world_model_forward_context = torch.no_grad if only_learn_policy_value_heads else nullcontext
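The masking hinges on `lens_to_mask`, whose definition is not shown in this hunk. A minimal sketch of the behavior it is assumed to have, namely turning per-sequence lengths into a boolean mask that is True for timesteps still inside the sequence:

```python
import torch

def lens_to_mask_sketch(lens: torch.Tensor, max_time: int) -> torch.Tensor:
    # lens: (batch,) integer lengths -> (batch, max_time) boolean mask
    seq = torch.arange(max_time, device = lens.device)
    return seq[None, :] < lens[:, None]

mask = lens_to_mask_sketch(torch.tensor([5, 3]), max_time = 5)
# tensor([[ True,  True,  True,  True,  True],
#         [ True,  True,  True, False, False]])
```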
@@ -2291,13 +2300,20 @@ class DynamicsWorldModel(Module):
 
         # handle entropy loss for naive exploration bonus
 
-        entropy_loss = - reduce(entropies, 'b t na -> b t', 'sum').mean()
+        entropy_loss = - reduce(entropies, 'b t na -> b t', 'sum')
 
         total_policy_loss = (
             policy_loss +
             entropy_loss * self.policy_entropy_weight
         )
 
+        # maybe handle variable lengths
+
+        if is_var_len:
+            total_policy_loss = total_policy_loss[mask].mean()
+        else:
+            total_policy_loss = total_policy_loss.mean()
+
         # maybe take policy optimizer step
 
         if exists(policy_optim):
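The entropy term is no longer reduced with `.mean()` up front, so the combined policy loss stays per-timestep (shape `b t`) and is only averaged over valid positions when lengths are given. A small standalone sketch of that masked reduction, with made-up tensors:

```python
import torch

batch, max_time = 2, 5
total_policy_loss = torch.randn(batch, max_time)             # per-timestep loss, shape (b, t)
mask = torch.tensor([[True] * 5, [True] * 3 + [False] * 2])  # from lens_to_mask

masked_mean = total_policy_loss[mask].mean()  # averages only the 8 valid timesteps
plain_mean  = total_policy_loss.mean()        # would also average the 2 padded ones
```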
@@ -2316,10 +2332,19 @@ class DynamicsWorldModel(Module):
 
         return_bins = self.reward_encoder(returns)
 
+        value_bins, return_bins, clipped_value_bins = tuple(rearrange(t, 'b t l -> b l t') for t in (value_bins, return_bins, clipped_value_bins))
+
         value_loss_1 = F.cross_entropy(value_bins, return_bins, reduction = 'none')
         value_loss_2 = F.cross_entropy(clipped_value_bins, return_bins, reduction = 'none')
 
-        value_loss = torch.maximum(value_loss_1, value_loss_2).mean()
+        value_loss = torch.maximum(value_loss_1, value_loss_2)
+
+        # maybe variable length
+
+        if is_var_len:
+            value_loss = value_loss[mask].mean()
+        else:
+            value_loss = value_loss.mean()
 
         # maybe take value optimizer step
 
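With `reduction = 'none'`, `F.cross_entropy` expects the class (bin) dimension at position 1, which is presumably why the bins are rearranged from `b t l` to `b l t` before computing per-timestep losses; the elementwise `torch.maximum` of the clipped and unclipped losses is then masked the same way as the policy loss. A hedged, self-contained sketch of that flow with invented shapes (the bin count and the exact form of the targets are assumptions):

```python
import torch
import torch.nn.functional as F

batch, time, num_bins = 2, 5, 41

# hypothetical logits / soft bin targets, already in (b, l, t) layout as in the hunk
value_bins         = torch.randn(batch, num_bins, time)
clipped_value_bins = torch.randn(batch, num_bins, time)
return_bins        = torch.softmax(torch.randn(batch, num_bins, time), dim = 1)

value_loss_1 = F.cross_entropy(value_bins, return_bins, reduction = 'none')          # (b, t)
value_loss_2 = F.cross_entropy(clipped_value_bins, return_bins, reduction = 'none')  # (b, t)

value_loss = torch.maximum(value_loss_1, value_loss_2)  # pessimistic choice between clipped / unclipped

mask = torch.tensor([[True] * 5, [True] * 3 + [False] * 2])
value_loss = value_loss[mask].mean()  # average over valid timesteps only
```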
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.76"
+version = "0.0.77"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }