Fix a few typo bugs. Support info in the return signature of the environment step. Temporarily turn off flex attention when the KV cache is used, to avoid a bug.
parent c0a6cd56a1
commit b0f6b8583d
@@ -1179,10 +1179,11 @@ def special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself = F
 def block_mask_special_tokens_right(
     seq_len,
-    num_tokens
+    num_tokens,
+    special_attend_only_itself = False
 ):
     def inner(b, h, q, k):
-        return special_token_mask(q, k, seq_len, num_tokens)
+        return special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself)

     return inner

 def compose_mask(mask1, mask2):
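For readers unfamiliar with flex attention's mask API, here is a minimal sketch of how a mask function with this (b, h, q, k) signature gets consumed by create_block_mask / flex_attention. The special_token_mask body below is a stand-in, not this repo's implementation, and the sketch assumes PyTorch 2.5+ with a CUDA device.

# Sketch only: stand-in mask logic, showing how `inner` plugs into flex attention.
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

def special_token_mask(q_idx, kv_idx, seq_len, num_tokens, special_attend_only_itself = False):
    # stand-in logic: special tokens occupy the last `num_tokens` positions;
    # when the flag is set, a special query attends only to itself, otherwise causal
    is_special_q = q_idx >= (seq_len - num_tokens)
    if special_attend_only_itself:
        return torch.where(is_special_q, q_idx == kv_idx, q_idx >= kv_idx)
    return q_idx >= kv_idx

def block_mask_special_tokens_right(seq_len, num_tokens, special_attend_only_itself = False):
    def inner(b, h, q, k):
        return special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself)
    return inner

seq_len, num_special = 256, 4
mask_mod = block_mask_special_tokens_right(seq_len, num_special, special_attend_only_itself = True)
block_mask = create_block_mask(mask_mod, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len)

q = k = v = torch.randn(2, 8, seq_len, 64, device = 'cuda')
out = flex_attention(q, k, v, block_mask = block_mask)  # (2, 8, 256, 64)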
@@ -1493,7 +1494,8 @@ class AxialSpaceTimeTransformer(Module):

         # attend functions for space and time

-        use_flex = exists(flex_attention) and tokens.is_cuda
+        has_kv_cache = exists(kv_cache)
+        use_flex = exists(flex_attention) and tokens.is_cuda and not has_kv_cache # KV cache shape breaks flex attention TODO: Fix

         attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, special_attend_only_itself = self.special_attend_only_itself, device = device)

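A hedged sketch of the kind of dispatch a use_flex flag like this typically guards; the attend function below is illustrative, not the repo's, but it shows why dropping back to scaled_dot_product_attention sidesteps the cached-KV shape issue during single-token decoding.

# Illustrative fallback pattern only - not the repo's actual attend function.
import torch
import torch.nn.functional as F

def attend(q, k, v, use_flex = False, block_mask = None, attn_mask = None):
    if use_flex:
        from torch.nn.attention.flex_attention import flex_attention
        return flex_attention(q, k, v, block_mask = block_mask)

    # fallback path: plain SDPA, which copes with mismatched query / key lengths
    return F.scaled_dot_product_attention(q, k, v, attn_mask = attn_mask)

# single-token decode step against a longer cached key / value sequence
q = torch.randn(1, 8, 1, 64)
k = torch.randn(1, 8, 32, 64)
v = torch.randn(1, 8, 32, 64)
out = attend(q, k, v, use_flex = False)  # (1, 8, 1, 64)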
@@ -1505,7 +1507,6 @@ class AxialSpaceTimeTransformer(Module):

         time_attn_kv_caches = []

-        has_kv_cache = exists(kv_cache)

         if has_kv_cache:
             past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
@@ -1847,7 +1848,7 @@ class VideoTokenizer(Module):

         losses = (recon_loss, lpips_loss)

-        return total_loss, TokenizerLosses(losses)
+        return total_loss, TokenizerLosses(*losses)

 # dynamics model, axial space-time transformer

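The change above is a real bug fix, not just style: assuming TokenizerLosses is a NamedTuple with one field per loss (two here), passing the tuple unstarred raises a TypeError rather than populating the fields. A minimal sketch with assumed field names:

from collections import namedtuple

# field names are an assumption for illustration; the real NamedTuple lives in the repo
TokenizerLosses = namedtuple('TokenizerLosses', ['recon', 'lpips'])

losses = (0.5, 0.1)

# TokenizerLosses(losses)  # TypeError: missing 1 required positional argument: 'lpips'
TokenizerLosses(*losses)   # TokenizerLosses(recon=0.5, lpips=0.1)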
@@ -2104,7 +2105,7 @@ class DynamicsWorldModel(Module):

         self.ppo_eps_clip = ppo_eps_clip
         self.value_clip = value_clip
-        self.policy_entropy_weight = value_clip
+        self.policy_entropy_weight = policy_entropy_weight

         # pmpo related

@@ -2127,7 +2128,7 @@ class DynamicsWorldModel(Module):
         self.flow_loss_normalizer = LossNormalizer(1)
         self.reward_loss_normalizer = LossNormalizer(multi_token_pred_len)
         self.discrete_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_discrete_actions > 0 else None
-        self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_discrete_actions > 0 else None
+        self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_continuous_actions > 0 else None

         self.latent_flow_loss_weight = latent_flow_loss_weight

@@ -2358,6 +2359,9 @@ class DynamicsWorldModel(Module):
             elif len(env_step_out) == 4:
                 next_frame, reward, terminated, truncated = env_step_out

+            elif len(env_step_out) == 5:
+                next_frame, reward, terminated, truncated, info = env_step_out
+
             # update episode lens

             episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)
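The new branch matches the Gymnasium-style step API, which returns (obs, reward, terminated, truncated, info); the 4-tuple branch stays for environments that omit info. A toy environment (made up purely for illustration) whose step output the new code would accept:

import numpy as np

class ToyEnv:
    # made-up environment, only to show the 5-tuple step return the new branch unpacks

    def reset(self):
        return np.zeros((3, 64, 64), dtype = np.float32)

    def step(self, action):
        next_frame = np.zeros((3, 64, 64), dtype = np.float32)
        reward = 0.0
        terminated = False
        truncated = False
        info = dict(anything = 'extra diagnostics the env wants to expose')
        return next_frame, reward, terminated, truncated, info

env_step_out = ToyEnv().step(action = 0)
assert len(env_step_out) == 5
next_frame, reward, terminated, truncated, info = env_step_out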
@@ -3085,8 +3089,8 @@ class DynamicsWorldModel(Module):
         if latents.ndim == 4:
             latents = rearrange(latents, 'b t v d -> b t v 1 d') # 1 latent edge case

-        assert latents.shape[-2:] == self.latent_shape
-        assert latents.shape[2] == self.num_video_views
+        assert latents.shape[-2:] == self.latent_shape, f'latents must have shape {self.latent_shape}, got {latents.shape[-2:]}'
+        assert latents.shape[2] == self.num_video_views, f'latents must have {self.num_video_views} views, got {latents.shape[2]}'

         # variables
