From fbfd59e42f24703eef772b11b4f5e69faba9679d Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Mon, 27 Oct 2025 06:09:09 -0700
Subject: [PATCH] handle variable lengthed experiences when doing policy
 optimization

---
 dreamer4/dreamer4.py | 31 ++++++++++++++++++++++++++++---
 pyproject.toml       |  2 +-
 2 files changed, 29 insertions(+), 4 deletions(-)

diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 8b507b7..5e0f61c 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -82,7 +82,7 @@ class Experience:
     log_probs: tuple[Tensor, Tensor] | None = None
     values: Tensor | None = None
     step_size: int | None = None
-    lens: Tensor | None = None,
+    lens: Tensor | None = None
     agent_index: int = 0
 
     is_from_world_model: bool = True
@@ -2226,6 +2226,15 @@ class DynamicsWorldModel(Module):
 
         returns = calc_gae(rewards, old_values, gamma = self.gae_discount_factor, lam = self.gae_lambda, use_accelerated = self.gae_use_accelerated)
 
+        # handle variable lengths
+
+        max_time = latents.shape[1]
+        is_var_len = exists(experience.lens)
+
+        if is_var_len:
+            lens = experience.lens
+            mask = lens_to_mask(lens, max_time)
+
         # determine whether to finetune entire transformer or just learn the heads
 
         world_model_forward_context = torch.no_grad if only_learn_policy_value_heads else nullcontext
@@ -2291,13 +2300,20 @@ class DynamicsWorldModel(Module):
 
         # handle entropy loss for naive exploration bonus
 
-        entropy_loss = - reduce(entropies, 'b t na -> b t', 'sum').mean()
+        entropy_loss = - reduce(entropies, 'b t na -> b t', 'sum')
 
         total_policy_loss = (
             policy_loss +
             entropy_loss * self.policy_entropy_weight
         )
 
+        # maybe handle variable lengths
+
+        if is_var_len:
+            total_policy_loss = total_policy_loss[mask].mean()
+        else:
+            total_policy_loss = total_policy_loss.mean()
+
         # maybe take policy optimizer step
 
         if exists(policy_optim):
@@ -2316,10 +2332,19 @@ class DynamicsWorldModel(Module):
 
         return_bins = self.reward_encoder(returns)
 
+        value_bins, return_bins, clipped_value_bins = tuple(rearrange(t, 'b t l -> b l t') for t in (value_bins, return_bins, clipped_value_bins))
+
         value_loss_1 = F.cross_entropy(value_bins, return_bins, reduction = 'none')
         value_loss_2 = F.cross_entropy(clipped_value_bins, return_bins, reduction = 'none')
 
-        value_loss = torch.maximum(value_loss_1, value_loss_2).mean()
+        value_loss = torch.maximum(value_loss_1, value_loss_2)
+
+        # maybe variable length
+
+        if is_var_len:
+            value_loss = value_loss[mask].mean()
+        else:
+            value_loss = value_loss.mean()
 
         # maybe take value optimizer step
 
diff --git a/pyproject.toml b/pyproject.toml
index 6622e4f..bb2ff1a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.76"
+version = "0.0.77"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
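
Note on the change: the patch keeps the policy and value losses per timestep (`reduction = 'none'`, no `.mean()`), builds a boolean mask from `experience.lens`, and averages only over valid timesteps so padded steps of shorter rollouts do not dilute the loss. Below is a minimal standalone sketch of that masked-mean idea. The local `lens_to_mask` here is a stand-in written for illustration; the patch relies on a helper of the same name in dreamer4.py whose exact implementation is not shown in this diff.

```python
import torch

def lens_to_mask(lens: torch.Tensor, max_time: int) -> torch.Tensor:
    # (batch,) lengths -> (batch, max_time) boolean mask, True where t < len
    # illustrative stand-in for the repo's helper of the same name
    arange = torch.arange(max_time, device = lens.device)
    return arange < lens[:, None]

# fake per-timestep losses for 2 rollouts padded to 5 steps (names are hypothetical)
per_step_loss = torch.randn(2, 5)
lens = torch.tensor([5, 3])                 # second rollout only has 3 valid steps

mask = lens_to_mask(lens, per_step_loss.shape[1])

masked_mean = per_step_loss[mask].mean()    # average over valid timesteps only, as in the patch
naive_mean = per_step_loss.mean()           # would let padded steps skew the average

print(masked_mean, naive_mean)
```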