handle variable-length experiences when doing policy optimization

lucidrains 2025-10-27 06:09:09 -07:00
parent 46432aee9b
commit fbfd59e42f
2 changed files with 29 additions and 4 deletions


@@ -82,7 +82,7 @@ class Experience:
     log_probs: tuple[Tensor, Tensor] | None = None
     values: Tensor | None = None
     step_size: int | None = None
-    lens: Tensor | None = None,
+    lens: Tensor | None = None
     agent_index: int = 0
     is_from_world_model: bool = True
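
The lone change above is a real bug fix, not cosmetic: at class scope a trailing comma turns the annotated default into a one-element tuple, so `lens` would default to `(None,)` rather than `None`, and any not-None check on `experience.lens` would wrongly treat every experience as variable length. A quick demonstration of the parse (hypothetical `Example` class, just to illustrate):

    from dataclasses import dataclass

    @dataclass
    class Example:
        lens: int | None = None,   # trailing comma: the default is the tuple (None,)

    print(Example().lens)  # (None,) -- not None, and truthy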
@@ -2226,6 +2226,15 @@ class DynamicsWorldModel(Module):
         returns = calc_gae(rewards, old_values, gamma = self.gae_discount_factor, lam = self.gae_lambda, use_accelerated = self.gae_use_accelerated)

+        # handle variable lengths
+
+        max_time = latents.shape[1]
+        is_var_len = exists(experience.lens)
+
+        if is_var_len:
+            lens = experience.lens
+            mask = lens_to_mask(lens, max_time)
+
         # determine whether to finetune entire transformer or just learn the heads

         world_model_forward_context = torch.no_grad if only_learn_policy_value_heads else nullcontext
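
The new block leans on two helpers defined elsewhere in the repo, `exists` and `lens_to_mask`. Their definitions are outside this diff; a minimal sketch of how such helpers typically look:

    import torch
    from torch import Tensor

    def exists(v):
        # not-None check
        return v is not None

    def lens_to_mask(lens: Tensor, max_len: int) -> Tensor:
        # lens: (batch,) integer lengths -> (batch, max_len) boolean mask,
        # True for valid timesteps, False over right padding
        arange = torch.arange(max_len, device = lens.device)
        return arange < lens[:, None]

For example, `lens_to_mask(torch.tensor([2, 4]), 4)` yields `[[True, True, False, False], [True, True, True, True]]`, matching the `(b, t)` shape of the per-timestep losses below.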
@@ -2291,13 +2300,20 @@ class DynamicsWorldModel(Module):
         # handle entropy loss for naive exploration bonus

-        entropy_loss = - reduce(entropies, 'b t na -> b t', 'sum').mean()
+        entropy_loss = - reduce(entropies, 'b t na -> b t', 'sum')

         total_policy_loss = (
             policy_loss +
             entropy_loss * self.policy_entropy_weight
         )

+        # maybe handle variable lengths
+
+        if is_var_len:
+            total_policy_loss = total_policy_loss[mask].mean()
+        else:
+            total_policy_loss = total_policy_loss.mean()
+
         # maybe take policy optimizer step

         if exists(policy_optim):
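
Deferring the `.mean()` until after masking is the point of the change: a plain mean over a padded `(b, t)` loss tensor counts padded timesteps, diluting the contribution of shorter experiences. A small illustration with made-up numbers:

    import torch

    loss = torch.tensor([[1., 1., 0., 0.],
                         [1., 1., 1., 1.]])   # (batch, time); row 0 is padded after step 2
    mask = torch.tensor([[True, True, False, False],
                         [True, True, True,  True]])

    print(loss.mean())        # tensor(0.7500) -- padding drags the average down
    print(loss[mask].mean())  # tensor(1.)     -- only the 6 valid timesteps contribute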
@@ -2316,10 +2332,19 @@ class DynamicsWorldModel(Module):
         return_bins = self.reward_encoder(returns)

         value_bins, return_bins, clipped_value_bins = tuple(rearrange(t, 'b t l -> b l t') for t in (value_bins, return_bins, clipped_value_bins))

         value_loss_1 = F.cross_entropy(value_bins, return_bins, reduction = 'none')
         value_loss_2 = F.cross_entropy(clipped_value_bins, return_bins, reduction = 'none')

-        value_loss = torch.maximum(value_loss_1, value_loss_2).mean()
+        value_loss = torch.maximum(value_loss_1, value_loss_2)
+
+        # maybe variable length
+
+        if is_var_len:
+            value_loss = value_loss[mask].mean()
+        else:
+            value_loss = value_loss.mean()

         # maybe take value optimizer step
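
For context, `torch.maximum(value_loss_1, value_loss_2)` is the PPO-style pessimistic value clipping, applied here as cross entropy over discretized return bins; `reduction = 'none'` keeps the loss at shape `(b, t)` so the same boolean mask can index it. A schematic of the same pattern with plain regression, assuming hypothetical `old_values` and a clip radius rather than the repo's actual bin-based implementation:

    import torch
    import torch.nn.functional as F

    def clipped_value_loss(values, old_values, returns, clip = 0.4):
        # keep new value predictions within `clip` of the old ones,
        # then take the worse (larger) of the clipped and unclipped losses
        clipped = old_values + (values - old_values).clamp(-clip, clip)
        unclipped_loss = F.mse_loss(values, returns, reduction = 'none')
        clipped_loss = F.mse_loss(clipped, returns, reduction = 'none')
        return torch.maximum(unclipped_loss, clipped_loss)  # (batch, time), reduce with the mask as above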


@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.76"
+version = "0.0.77"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }