From 58a6964dd9ca1b23dc8460ed98427f23ca79b495 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Fri, 3 Oct 2025 11:59:07 -0700
Subject: [PATCH] the dynamics model has a spatial attention with a non-causal
 attention pattern but nothing else attending to agent tokens

---
 dreamer4/dreamer4.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 8701006..f545b2a 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -252,10 +252,11 @@ class MultiHeadRMSNorm(Module):
 # masking related
 
 # block causal mask (space fully attends within each block, while time is causal)
-def flex_block_causal_mask(
+def flex_block_mask(
     seq_len,
     block_size,
     num_special_tokens = 0,
+    is_causal = True,
     prevent_modality_to_special = False, # encoder of tokenizer as well as (perhaps crucially) the dynamics model
     prevent_special_to_modality = False # decoder of tokenizer
 ):
@@ -273,13 +274,16 @@ def flex_block_causal_mask(
     q_is_special = (qi % block_size) >= special_token_index_start
     k_is_special = (ki % block_size) >= special_token_index_start
 
-    causal_mask = q_block_index >= k_block_index
+    # start fully permissive, then intersect the requested constraints;
+    # time-causality is now optional so spatial attention can be non-causal
+    mask = True
+
+    if is_causal:
+        mask &= q_block_index >= k_block_index
 
     if prevent_modality_to_special:
-        causal_mask = causal_mask & ~(q_is_special & ~k_is_special)
+        mask &= ~(q_is_special & ~k_is_special)
 
     if prevent_special_to_modality:
-        causal_mask = causal_mask & ~(~q_is_special & k_is_special)
+        mask &= ~(~q_is_special & k_is_special)
 
-    return causal_mask
+    return mask
 