addressing issues raised by an independent researcher with llm assistance

lucidrains 2025-10-31 08:37:39 -07:00
parent 60681fce1d
commit d756d1bb8c
2 changed files with 14 additions and 14 deletions

View File

@@ -1331,6 +1331,12 @@ class Attention(Module):
         q = self.q_heads_rmsnorm(q)
         k = self.k_heads_rmsnorm(k)
 
+        # rotary
+
+        if exists(rotary_pos_emb):
+            q = apply_rotations(rotary_pos_emb, q)
+            k = apply_rotations(rotary_pos_emb, k)
+
         # caching
 
         if exists(kv_cache):
@@ -1338,12 +1344,6 @@ class Attention(Module):
             k = cat((ck, k), dim = -2)
             v = cat((cv, v), dim = -2)
 
-        # rotary
-
-        if exists(rotary_pos_emb):
-            q = apply_rotations(rotary_pos_emb, q)
-            k = apply_rotations(rotary_pos_emb, k)
-
         # attention
 
         attend_fn = default(attend_fn, naive_attend)
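For context on the reorder above: with a kv cache, the cached keys were already rotated at their absolute positions when they were first computed, so rotary should be applied to just the incoming query/key at its offset before the cache is concatenated; applying it after the concat rotates the cached keys a second time. A minimal sketch of the intended ordering, using a stand-in apply_rotary rather than the repo's apply_rotations:

import torch

def rotary_freqs(seq_len, dim, offset = 0):
    # standard rotary frequencies for absolute positions [offset, offset + seq_len)
    inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
    pos = torch.arange(offset, offset + seq_len).float()
    freqs = torch.einsum('i,j->ij', pos, inv_freq)
    return torch.cat((freqs, freqs), dim = -1)

def apply_rotary(freqs, t):
    # rotate-half style application of the position-dependent angles
    x1, x2 = t.chunk(2, dim = -1)
    rotated = torch.cat((-x2, x1), dim = -1)
    return t * freqs.cos() + rotated * freqs.sin()

# decoding one new step against a cache of 4 already-rotated keys:
# rotate only the incoming key at its absolute offset, then append the cache

cache_len, dim = 4, 8
cached_k = apply_rotary(rotary_freqs(cache_len, dim), torch.randn(1, cache_len, dim))
new_k = torch.randn(1, 1, dim)

new_k = apply_rotary(rotary_freqs(1, dim, offset = cache_len), new_k)
k = torch.cat((cached_k, new_k), dim = -2)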
@@ -1507,12 +1507,11 @@ class AxialSpaceTimeTransformer(Module):
         has_kv_cache = exists(kv_cache)
 
         if has_kv_cache:
             past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
 
             rotary_seq_len = 1
-            rotary_pos_offset = past_tokens.shape[-2]
+            rotary_pos_offset = past_tokens.shape[1]
         else:
             rotary_seq_len = time
             rotary_pos_offset = 0
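A minimal shape check for the offset fix above, assuming the axial transformer lays tokens out as (batch, time, space, dim), so the number of past frames lives on dim 1 while dim -2 is the spatial axis (sizes below are made up):

import torch

past_tokens = torch.randn(2, 7, 16, 512)    # batch, time, space, dim (hypothetical sizes)

rotary_pos_offset = past_tokens.shape[1]    # 7  - number of past frames, the intended offset
wrong_offset      = past_tokens.shape[-2]   # 16 - the spatial axis, which the old code used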
@@ -1687,6 +1686,7 @@ class VideoTokenizer(Module):
             time_block_every = time_block_every,
             num_special_spatial_tokens = num_latent_tokens,
             num_residual_streams = num_residual_streams,
+            special_attend_only_itself = True,
             final_norm = True
         )
@@ -2456,8 +2456,8 @@ class DynamicsWorldModel(Module):
         if exists(experience.lens):
             mask_for_gae = lens_to_mask(experience.lens, time)
 
-            rewards = rewards.masked_fill(mask_for_gae, 0.)
-            old_values = old_values.masked_fill(mask_for_gae, 0.)
+            rewards = rewards.masked_fill(~mask_for_gae, 0.)
+            old_values = old_values.masked_fill(~mask_for_gae, 0.)
 
         # calculate returns
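A small sketch of the mask convention assumed by this fix: lens_to_mask is taken to be True inside each sequence's valid length, so it is the padded tail that should be zeroed, i.e. fill where the mask is False:

import torch

def lens_to_mask(lens, total_len):
    # True for positions < length, False for padding
    return torch.arange(total_len) < lens[:, None]

lens = torch.tensor([3, 5])
rewards = torch.ones(2, 5)

mask = lens_to_mask(lens, 5)                # [[T, T, T, F, F], [T, T, T, T, T]]
rewards = rewards.masked_fill(~mask, 0.)    # zeros only the padded tail, as the fix intends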
@@ -2492,7 +2492,7 @@ class DynamicsWorldModel(Module):
         # mean, var - todo - handle distributed
 
-        returns_mean, returns_var = returns.mean(), returns.var()
+        returns_mean, returns_var = returns_for_stats.mean(), returns_for_stats.var()
 
         # ema
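returns_for_stats is not shown in this diff; the likely intent, sketched below under that assumption, is that the normalization statistics are computed only over valid (non-padded) timesteps rather than over the raw returns tensor:

import torch

returns = torch.randn(2, 5)
mask = torch.tensor([[True, True, True, False, False],
                     [True, True, True, True,  True]])

returns_for_stats = returns[mask]    # only the valid timesteps contribute
returns_mean, returns_var = returns_for_stats.mean(), returns_for_stats.var()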
@@ -3510,7 +3510,7 @@ class DynamicsWorldModel(Module):
         reward_losses = F.cross_entropy(reward_pred, reward_targets, reduction = 'none')
-        reward_losses = reward_losses.masked_fill(reward_loss_mask, 0.)
+        reward_losses = reward_losses.masked_fill(~reward_loss_mask, 0.)
 
         if is_var_len:
             reward_loss = reward_losses[loss_mask_without_last].mean(dim = 0)
@@ -3554,7 +3554,7 @@ class DynamicsWorldModel(Module):
             discrete_mask = rearrange(discrete_mask, 'b t mtp -> mtp b t')
 
         if exists(continuous_actions):
-            continuous_action_targets, continuous_mask = create_multi_token_prediction_targets(discrete_actions, self.multi_token_pred_len)
+            continuous_action_targets, continuous_mask = create_multi_token_prediction_targets(continuous_actions, self.multi_token_pred_len)
 
             continuous_action_targets = rearrange(continuous_action_targets, 'b t mtp ... -> mtp b t ...')
             continuous_mask = rearrange(continuous_mask, 'b t mtp -> mtp b t')

View File

@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.1.0"
+version = "0.1.2"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }