fix a few typo bugs. Support info in return signature of environment step. Temporarily turn off flex attention when the kv_cache is used to avoid bug.

This commit is contained in:
j 2025-11-04 17:29:12 -05:00
parent c0a6cd56a1
commit b0f6b8583d

View File

@ -1179,10 +1179,11 @@ def special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself = F
def block_mask_special_tokens_right(
seq_len,
num_tokens
num_tokens,
special_attend_only_itself = False
):
def inner(b, h, q, k):
return special_token_mask(q, k, seq_len, num_tokens)
return special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself)
return inner
def compose_mask(mask1, mask2):
@ -1493,7 +1494,8 @@ class AxialSpaceTimeTransformer(Module):
# attend functions for space and time
use_flex = exists(flex_attention) and tokens.is_cuda
has_kv_cache = exists(kv_cache)
use_flex = exists(flex_attention) and tokens.is_cuda and not has_kv_cache # KV cache shape breaks flex attention TODO: Fix
attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, special_attend_only_itself = self.special_attend_only_itself, device = device)
@ -1505,7 +1507,6 @@ class AxialSpaceTimeTransformer(Module):
time_attn_kv_caches = []
has_kv_cache = exists(kv_cache)
if has_kv_cache:
past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
@ -1847,7 +1848,7 @@ class VideoTokenizer(Module):
losses = (recon_loss, lpips_loss)
return total_loss, TokenizerLosses(losses)
return total_loss, TokenizerLosses(*losses)
# dynamics model, axial space-time transformer
@ -2104,7 +2105,7 @@ class DynamicsWorldModel(Module):
self.ppo_eps_clip = ppo_eps_clip
self.value_clip = value_clip
self.policy_entropy_weight = value_clip
self.policy_entropy_weight = policy_entropy_weight
# pmpo related
@ -2127,7 +2128,7 @@ class DynamicsWorldModel(Module):
self.flow_loss_normalizer = LossNormalizer(1)
self.reward_loss_normalizer = LossNormalizer(multi_token_pred_len)
self.discrete_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_discrete_actions > 0 else None
self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_discrete_actions > 0 else None
self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_continuous_actions > 0 else None
self.latent_flow_loss_weight = latent_flow_loss_weight
@ -2358,6 +2359,9 @@ class DynamicsWorldModel(Module):
elif len(env_step_out) == 4:
next_frame, reward, terminated, truncated = env_step_out
elif len(env_step_out) == 5:
next_frame, reward, terminated, truncated, info = env_step_out
# update episode lens
episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)
@ -3085,8 +3089,8 @@ class DynamicsWorldModel(Module):
if latents.ndim == 4:
latents = rearrange(latents, 'b t v d -> b t v 1 d') # 1 latent edge case
assert latents.shape[-2:] == self.latent_shape
assert latents.shape[2] == self.num_video_views
assert latents.shape[-2:] == self.latent_shape, f'latents must have shape {self.latent_shape}, got {latents.shape[-2:]}'
assert latents.shape[2] == self.num_video_views, f'latents must have {self.num_video_views} views, got {latents.shape[2]}'
# variables