From b0f6b8583dd96f4624322fe0cea0d3dc5e83fc89 Mon Sep 17 00:00:00 2001 From: j Date: Tue, 4 Nov 2025 17:29:12 -0500 Subject: [PATCH] fix a few typo bugs. Support info in return signature of environment step. Temporarily turn off flex attention when the kv_cache is used to avoid bug. --- dreamer4/dreamer4.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index b9885ec..346e9ce 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -1179,10 +1179,11 @@ def special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself = F def block_mask_special_tokens_right( seq_len, - num_tokens + num_tokens, + special_attend_only_itself = False ): def inner(b, h, q, k): - return special_token_mask(q, k, seq_len, num_tokens) + return special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself) return inner def compose_mask(mask1, mask2): @@ -1493,7 +1494,8 @@ class AxialSpaceTimeTransformer(Module): # attend functions for space and time - use_flex = exists(flex_attention) and tokens.is_cuda + has_kv_cache = exists(kv_cache) + use_flex = exists(flex_attention) and tokens.is_cuda and not has_kv_cache # KV cache shape breaks flex attention TODO: Fix attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, special_attend_only_itself = self.special_attend_only_itself, device = device) @@ -1505,7 +1507,6 @@ class AxialSpaceTimeTransformer(Module): time_attn_kv_caches = [] - has_kv_cache = exists(kv_cache) if has_kv_cache: past_tokens, tokens = tokens[:, :-1], tokens[:, -1:] @@ -1847,7 +1848,7 @@ class VideoTokenizer(Module): losses = (recon_loss, lpips_loss) - return total_loss, TokenizerLosses(losses) + return total_loss, TokenizerLosses(*losses) # dynamics model, axial space-time transformer @@ -2104,7 +2105,7 @@ class DynamicsWorldModel(Module): self.ppo_eps_clip = ppo_eps_clip self.value_clip = value_clip - self.policy_entropy_weight = value_clip + self.policy_entropy_weight = policy_entropy_weight # pmpo related @@ -2127,7 +2128,7 @@ class DynamicsWorldModel(Module): self.flow_loss_normalizer = LossNormalizer(1) self.reward_loss_normalizer = LossNormalizer(multi_token_pred_len) self.discrete_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_discrete_actions > 0 else None - self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_discrete_actions > 0 else None + self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_continuous_actions > 0 else None self.latent_flow_loss_weight = latent_flow_loss_weight @@ -2358,6 +2359,9 @@ class DynamicsWorldModel(Module): elif len(env_step_out) == 4: next_frame, reward, terminated, truncated = env_step_out + elif len(env_step_out) == 5: + next_frame, reward, terminated, truncated, info = env_step_out + # update episode lens episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1) @@ -3085,8 +3089,8 @@ class DynamicsWorldModel(Module): if latents.ndim == 4: latents = rearrange(latents, 'b t v d -> b t v 1 d') # 1 latent edge case - assert latents.shape[-2:] == self.latent_shape - assert latents.shape[2] == self.num_video_views + assert latents.shape[-2:] == self.latent_shape, f'latents must have shape {self.latent_shape}, got {latents.shape[-2:]}' + assert latents.shape[2] == self.num_video_views, f'latents must have {self.num_video_views} views, got {latents.shape[2]}' # variables