addressing issues raised by an independent researcher with LLM assistance

lucidrains 2025-10-31 08:26:33 -07:00
parent 60681fce1d
commit ef367969f8
2 changed files with 13 additions and 13 deletions


@@ -1331,6 +1331,12 @@ class Attention(Module):
         q = self.q_heads_rmsnorm(q)
         k = self.k_heads_rmsnorm(k)
 
+        # rotary
+
+        if exists(rotary_pos_emb):
+            q = apply_rotations(rotary_pos_emb, q)
+            k = apply_rotations(rotary_pos_emb, k)
+
         # caching
 
         if exists(kv_cache):
@@ -1338,12 +1344,6 @@ class Attention(Module):
             k = cat((ck, k), dim = -2)
             v = cat((cv, v), dim = -2)
 
-        # rotary
-
-        if exists(rotary_pos_emb):
-            q = apply_rotations(rotary_pos_emb, q)
-            k = apply_rotations(rotary_pos_emb, k)
-
         # attention
 
         attend_fn = default(attend_fn, naive_attend)
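
Note: these two hunks move the rotary application from after the kv cache concatenation to before it. During cached decoding, rotary_pos_emb covers only the newly decoded position, so under the old order that single rotation was broadcast over every cached key as well; rotating first means each key is rotated exactly once, at its own position, and enters the cache already rotated. A minimal sketch of the corrected ordering, with hypothetical shapes and a simplified stand-in for apply_rotations:

    import torch

    def rotate_half(x):
        x1, x2 = x.chunk(2, dim = -1)
        return torch.cat((-x2, x1), dim = -1)

    def apply_rotations(freqs, t):
        # freqs: (seq, dim) angles, t: (batch, heads, seq, dim)
        return t * freqs.cos() + rotate_half(t) * freqs.sin()

    dim, num_past = 8, 4
    inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))

    # angles for the single new position, offset by the cache length
    pos = torch.arange(num_past, num_past + 1).float()
    freqs = torch.cat((pos[:, None] * inv_freq,) * 2, dim = -1)  # (1, dim)

    k_new = torch.randn(1, 2, 1, dim)
    k_new = apply_rotations(freqs, k_new)        # rotate first, at its own position
    cached_k = torch.randn(1, 2, num_past, dim)  # cache already holds rotated keys
    k = torch.cat((cached_k, k_new), dim = -2)   # then concatenate for attention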
@@ -1507,12 +1507,11 @@ class AxialSpaceTimeTransformer(Module):
         has_kv_cache = exists(kv_cache)
 
         if has_kv_cache:
             past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
-            rotary_seq_len = 1
-            rotary_pos_offset = past_tokens.shape[-2]
+            rotary_pos_offset = past_tokens.shape[1]
         else:
             rotary_seq_len = time
             rotary_pos_offset = 0
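
Note: the token tensor in the axial space-time transformer presumably carries a separate spatial axis, e.g. (batch, time, space, dim), so shape[-2] was reading the spatial width where the rotary offset needs the count of past time steps on axis 1 (the dropped rotary_seq_len = 1 is presumably derived elsewhere now). A small sketch of the off-by-axis bug under that assumed layout:

    import torch

    tokens = torch.randn(2, 7, 16, 64)  # assumed (batch, time, space, dim)

    # cached decoding keeps only the newest time step
    past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]

    print(past_tokens.shape[-2])  # 16 - the spatial axis, the wrong offset
    print(past_tokens.shape[1])   # 6  - past time steps, the intended offset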
@@ -1687,6 +1686,7 @@ class VideoTokenizer(Module):
             time_block_every = time_block_every,
             num_special_spatial_tokens = num_latent_tokens,
             num_residual_streams = num_residual_streams,
+            special_attend_only_itself = True,
             final_norm = True
         )
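
Note: the new special_attend_only_itself flag presumably restricts the special latent tokens to attending only to themselves rather than to the patch tokens. A guess at the intent, expressed as an attention mask (hypothetical, not the library's actual implementation):

    import torch

    num_patch, num_special = 6, 2
    n = num_patch + num_special

    mask = torch.ones(n, n, dtype = torch.bool)  # True = may attend
    mask[num_patch:, :] = False                  # special-token queries blocked ...
    mask[num_patch:, num_patch:] = torch.eye(num_special, dtype = torch.bool)  # ... except to themselves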
@@ -2456,8 +2456,8 @@ class DynamicsWorldModel(Module):
         if exists(experience.lens):
             mask_for_gae = lens_to_mask(experience.lens, time)
 
-            rewards = rewards.masked_fill(mask_for_gae, 0.)
-            old_values = old_values.masked_fill(mask_for_gae, 0.)
+            rewards = rewards.masked_fill(~mask_for_gae, 0.)
+            old_values = old_values.masked_fill(~mask_for_gae, 0.)
 
         # calculate returns
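
Note: a mask polarity fix. lens_to_mask presumably returns True at valid (within-length) positions, so the old code zeroed the real rewards and values and kept the padding. A minimal sketch assuming that convention:

    import torch

    def lens_to_mask(lens, total):
        return torch.arange(total) < lens[:, None]  # True where t < length

    lens = torch.tensor([3, 5])
    rewards = torch.ones(2, 5)
    mask = lens_to_mask(lens, 5)

    print(rewards.masked_fill(mask, 0.))   # old: wipes out the valid rewards
    print(rewards.masked_fill(~mask, 0.))  # fixed: zeroes only the padding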
@@ -2492,7 +2492,7 @@ class DynamicsWorldModel(Module):
         # mean, var - todo - handle distributed
 
-        returns_mean, returns_var = returns.mean(), returns.var()
+        returns_mean, returns_var = returns_for_stats.mean(), returns_for_stats.var()
 
         # ema
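
Note: returns_for_stats is not defined in this diff; presumably it gathers only the valid timesteps, so the zeros written into the padded positions above no longer bias the mean and variance used for return normalization. A sketch of the presumed difference (the gather shown is hypothetical):

    import torch

    returns = torch.tensor([[1., 2., 3., 0., 0.]])  # last two entries are padding
    mask = torch.tensor([[True, True, True, False, False]])

    returns_for_stats = returns[mask]  # hypothetical: keep valid entries only

    print(returns.mean())            # 1.2 - dragged toward zero by padding
    print(returns_for_stats.mean())  # 2.0 - statistics over real returns only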
@@ -3510,7 +3510,7 @@ class DynamicsWorldModel(Module):
         reward_losses = F.cross_entropy(reward_pred, reward_targets, reduction = 'none')
-        reward_losses = reward_losses.masked_fill(reward_loss_mask, 0.)
+        reward_losses = reward_losses.masked_fill(~reward_loss_mask, 0.)
 
         if is_var_len:
             reward_loss = reward_losses[loss_mask_without_last].mean(dim = 0)
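
Note: the same mask polarity fix as the GAE hunk above; assuming reward_loss_mask is True at positions that should contribute to the loss, the inversion zeroes only the excluded entries.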

pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.1.0"
+version = "0.1.1"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }