From d756d1bb8c1d13e8e0b64fe1f002d54ab76c14cc Mon Sep 17 00:00:00 2001 From: lucidrains Date: Fri, 31 Oct 2025 08:37:39 -0700 Subject: [PATCH] addressing issues raised by an independent researcher with llm assistance --- dreamer4/dreamer4.py | 26 +++++++++++++------------- pyproject.toml | 2 +- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 340ecc4..b9885ec 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -1331,6 +1331,12 @@ class Attention(Module): q = self.q_heads_rmsnorm(q) k = self.k_heads_rmsnorm(k) + # rotary + + if exists(rotary_pos_emb): + q = apply_rotations(rotary_pos_emb, q) + k = apply_rotations(rotary_pos_emb, k) + # caching if exists(kv_cache): @@ -1338,12 +1344,6 @@ class Attention(Module): k = cat((ck, k), dim = -2) v = cat((cv, v), dim = -2) - # rotary - - if exists(rotary_pos_emb): - q = apply_rotations(rotary_pos_emb, q) - k = apply_rotations(rotary_pos_emb, k) - # attention attend_fn = default(attend_fn, naive_attend) @@ -1507,12 +1507,11 @@ class AxialSpaceTimeTransformer(Module): has_kv_cache = exists(kv_cache) - if has_kv_cache: past_tokens, tokens = tokens[:, :-1], tokens[:, -1:] rotary_seq_len = 1 - rotary_pos_offset = past_tokens.shape[-2] + rotary_pos_offset = past_tokens.shape[1] else: rotary_seq_len = time rotary_pos_offset = 0 @@ -1687,6 +1686,7 @@ class VideoTokenizer(Module): time_block_every = time_block_every, num_special_spatial_tokens = num_latent_tokens, num_residual_streams = num_residual_streams, + special_attend_only_itself = True, final_norm = True ) @@ -2456,8 +2456,8 @@ class DynamicsWorldModel(Module): if exists(experience.lens): mask_for_gae = lens_to_mask(experience.lens, time) - rewards = rewards.masked_fill(mask_for_gae, 0.) - old_values = old_values.masked_fill(mask_for_gae, 0.) + rewards = rewards.masked_fill(~mask_for_gae, 0.) + old_values = old_values.masked_fill(~mask_for_gae, 0.) # calculate returns @@ -2492,7 +2492,7 @@ class DynamicsWorldModel(Module): # mean, var - todo - handle distributed - returns_mean, returns_var = returns.mean(), returns.var() + returns_mean, returns_var = returns_for_stats.mean(), returns_for_stats.var() # ema @@ -3510,7 +3510,7 @@ class DynamicsWorldModel(Module): reward_losses = F.cross_entropy(reward_pred, reward_targets, reduction = 'none') - reward_losses = reward_losses.masked_fill(reward_loss_mask, 0.) + reward_losses = reward_losses.masked_fill(~reward_loss_mask, 0.) if is_var_len: reward_loss = reward_losses[loss_mask_without_last].mean(dim = 0) @@ -3554,7 +3554,7 @@ class DynamicsWorldModel(Module): discrete_mask = rearrange(discrete_mask, 'b t mtp -> mtp b t') if exists(continuous_actions): - continuous_action_targets, continuous_mask = create_multi_token_prediction_targets(discrete_actions, self.multi_token_pred_len) + continuous_action_targets, continuous_mask = create_multi_token_prediction_targets(continuous_actions, self.multi_token_pred_len) continuous_action_targets = rearrange(continuous_action_targets, 'b t mtp ... -> mtp b t ...') continuous_mask = rearrange(continuous_mask, 'b t mtp -> mtp b t') diff --git a/pyproject.toml b/pyproject.toml index ab8ddac..30b3009 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.1.0" +version = "0.1.2" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }