Fix a few typo bugs. Support info in the return signature of the environment step. Temporarily turn off flex attention when the kv cache is used, to avoid a bug.
commit b0f6b8583d
parent c0a6cd56a1
@@ -1179,10 +1179,11 @@ def special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself = F
 
 def block_mask_special_tokens_right(
     seq_len,
-    num_tokens
+    num_tokens,
+    special_attend_only_itself = False
 ):
     def inner(b, h, q, k):
-        return special_token_mask(q, k, seq_len, num_tokens)
+        return special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself)
     return inner
 
 def compose_mask(mask1, mask2):
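For context, a minimal sketch (not the repository's code) of how a mask factory shaped like `block_mask_special_tokens_right` is typically consumed by PyTorch's flex attention: the returned closure is a `mask_mod` for `create_block_mask`, and the resulting block mask is passed to `flex_attention`. The `special_token_mask` body below is a simplified stand-in, since the real implementation lives outside this hunk.

```python
# illustrative sketch only; `special_token_mask` semantics here are assumptions,
# not the repository's implementation
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def special_token_mask(q_idx, kv_idx, seq_len, num_tokens, special_attend_only_itself = False):
    # stand-in semantics: the last `num_tokens` positions are "special" tokens on the right;
    # normal keys are visible to everyone, special keys only to themselves
    kv_is_special = kv_idx >= (seq_len - num_tokens)
    q_is_special = q_idx >= (seq_len - num_tokens)

    mask = ~kv_is_special | (q_idx == kv_idx)

    if special_attend_only_itself:
        # special queries additionally attend only to themselves
        mask = mask & (~q_is_special | (q_idx == kv_idx))

    return mask

def block_mask_special_tokens_right(seq_len, num_tokens, special_attend_only_itself = False):
    def inner(b, h, q, k):
        return special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself)
    return inner

if torch.cuda.is_available():
    seq_len, num_special = 256, 2
    mask_mod = block_mask_special_tokens_right(seq_len, num_special)

    # B = None / H = None broadcasts the mask over batch and heads
    block_mask = create_block_mask(mask_mod, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, device = 'cuda')

    q = k = v = torch.randn(1, 8, seq_len, 64, device = 'cuda')
    out = flex_attention(q, k, v, block_mask = block_mask)  # (1, 8, 256, 64)
```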
@@ -1493,7 +1494,8 @@ class AxialSpaceTimeTransformer(Module):
 
         # attend functions for space and time
 
-        use_flex = exists(flex_attention) and tokens.is_cuda
+        has_kv_cache = exists(kv_cache)
+        use_flex = exists(flex_attention) and tokens.is_cuda and not has_kv_cache # KV cache shape breaks flex attention TODO: Fix
 
         attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, special_attend_only_itself = self.special_attend_only_itself, device = device)
 
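The hunk above gates flex attention off whenever a kv cache is present. A hypothetical sketch of that fallback pattern (names are illustrative, not the repo's API): during cached decoding only the newest token is fed as the query, so attention is routed through plain scaled dot-product attention instead of the prebuilt flex block mask.

```python
# assumed helper names (`attend`, `kv_cache_present`) for illustration only
import torch
import torch.nn.functional as F

def attend(q, k, v, use_flex = False, block_mask = None):
    if use_flex:
        from torch.nn.attention.flex_attention import flex_attention
        return flex_attention(q, k, v, block_mask = block_mask)

    # fallback while the kv-cache / flex-attention interaction is being fixed
    return F.scaled_dot_product_attention(q, k, v)

kv_cache_present = True
tokens_on_cuda = torch.cuda.is_available()

# mirrors the diff's guard: never use flex attention when decoding from a kv cache
use_flex = tokens_on_cuda and not kv_cache_present

q = torch.randn(1, 8, 1, 64)        # single new query token during cached decoding
k = v = torch.randn(1, 8, 128, 64)  # keys / values include the cached history
out = attend(q, k, v, use_flex = use_flex)
```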
@@ -1505,7 +1507,6 @@ class AxialSpaceTimeTransformer(Module):
 
         time_attn_kv_caches = []
 
-        has_kv_cache = exists(kv_cache)
 
         if has_kv_cache:
             past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
@@ -1847,7 +1848,7 @@ class VideoTokenizer(Module):
 
         losses = (recon_loss, lpips_loss)
 
-        return total_loss, TokenizerLosses(losses)
+        return total_loss, TokenizerLosses(*losses)
 
 # dynamics model, axial space-time transformer
 
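The `*losses` fix above matters because TokenizerLosses is presumably a NamedTuple-style container with one field per loss, so the tuple must be splatted into positional arguments. A minimal reproduction with a stand-in definition (field names are assumed):

```python
from collections import namedtuple

# stand-in for the repository's TokenizerLosses container
TokenizerLosses = namedtuple('TokenizerLosses', ['recon_loss', 'lpips_loss'])

losses = (0.12, 0.03)

TokenizerLosses(*losses)   # TokenizerLosses(recon_loss=0.12, lpips_loss=0.03)
# TokenizerLosses(losses)  # before the fix: TypeError, missing argument 'lpips_loss'
```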
@@ -2104,7 +2105,7 @@ class DynamicsWorldModel(Module):
 
         self.ppo_eps_clip = ppo_eps_clip
         self.value_clip = value_clip
-        self.policy_entropy_weight = value_clip
+        self.policy_entropy_weight = policy_entropy_weight
 
         # pmpo related
 
@@ -2127,7 +2128,7 @@ class DynamicsWorldModel(Module):
         self.flow_loss_normalizer = LossNormalizer(1)
         self.reward_loss_normalizer = LossNormalizer(multi_token_pred_len)
         self.discrete_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_discrete_actions > 0 else None
-        self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_discrete_actions > 0 else None
+        self.continuous_actions_loss_normalizer = LossNormalizer(multi_token_pred_len) if num_continuous_actions > 0 else None
 
         self.latent_flow_loss_weight = latent_flow_loss_weight
 
@@ -2358,6 +2359,9 @@ class DynamicsWorldModel(Module):
            elif len(env_step_out) == 4:
                next_frame, reward, terminated, truncated = env_step_out
 
+           elif len(env_step_out) == 5:
+               next_frame, reward, terminated, truncated, info = env_step_out
+
            # update episode lens
 
            episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)
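The new branch above accepts the 5-tuple (frame, reward, terminated, truncated, info) that a Gymnasium-style `env.step` returns. A stand-in sketch of the unpacking logic, not the repository's rollout loop (the helper name is hypothetical):

```python
def unpack_env_step(env_step_out):
    # default info to an empty dict for environments that do not return one
    info = {}

    if len(env_step_out) == 4:
        next_frame, reward, terminated, truncated = env_step_out
    elif len(env_step_out) == 5:
        next_frame, reward, terminated, truncated, info = env_step_out
    else:
        raise ValueError(f'unexpected number of values from env.step: {len(env_step_out)}')

    return next_frame, reward, terminated, truncated, info

# gymnasium-style step output: (obs, reward, terminated, truncated, info)
frame, reward, terminated, truncated, info = unpack_env_step(('frame', 1.0, False, False, {'lives': 3}))
```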
@@ -3085,8 +3089,8 @@ class DynamicsWorldModel(Module):
         if latents.ndim == 4:
             latents = rearrange(latents, 'b t v d -> b t v 1 d') # 1 latent edge case
 
-        assert latents.shape[-2:] == self.latent_shape
-        assert latents.shape[2] == self.num_video_views
+        assert latents.shape[-2:] == self.latent_shape, f'latents must have shape {self.latent_shape}, got {latents.shape[-2:]}'
+        assert latents.shape[2] == self.num_video_views, f'latents must have {self.num_video_views} views, got {latents.shape[2]}'
 
         # variables
 
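For reference, an illustrative snippet of the shape contract the improved assert messages describe, with made-up sizes: latents are (batch, time, views, num latents, dim), and a 4d input is treated as the single-latent edge case.

```python
# made-up sizes for illustration; latent_shape and num_video_views are assumptions
import torch
from einops import rearrange

latent_shape = (1, 32)   # (num latents, dim)
num_video_views = 2

latents = torch.randn(3, 8, num_video_views, 32)   # 4d input -> single latent edge case

if latents.ndim == 4:
    latents = rearrange(latents, 'b t v d -> b t v 1 d')

assert latents.shape[-2:] == latent_shape, f'latents must have shape {latent_shape}, got {latents.shape[-2:]}'
assert latents.shape[2] == num_video_views, f'latents must have {num_video_views} views, got {latents.shape[2]}'
```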