# dreamer4/dreamer4.py - state after patch be2aa43..9d34607, hunk @@ -156,12 +156,36 @@

# masking related
# block causal mask (space fully attends within each block, while time is causal)

def block_causal_mask_mod(
    block_size,
    num_special_tokens = 0,
    prevent_modality_to_special = False,
    prevent_special_to_modality = False
):
    """Return the `(b, h, q_idx, kv_idx) -> bool` predicate for the block causal mask.

    The sequence is treated as consecutive blocks of `block_size` positions.
    A query may attend to a key when the key's block index is not greater than
    the query's block index (full attention within a block, causal across blocks).

    Special tokens (either latent or agent tokens) are assumed to be placed at
    the right hand side of every block, so each block is laid out as
    [modality] [latents | agent], with the last `num_special_tokens` positions
    of the block being special.

    Args:
        block_size: number of positions per block.
        num_special_tokens: number of trailing positions per block that are special.
        prevent_modality_to_special: if True, modality queries may not attend to
            special keys (special queries still attend to everything the causal
            rule allows). Pattern noted for the encoder of the tokenizer as well
            as (perhaps crucially) the dynamics model.
        prevent_special_to_modality: if True, special queries may not attend to
            modality keys (modality queries still attend to everything the
            causal rule allows). Pattern noted for the decoder of the tokenizer.

    Returns:
        A function taking `(batch, head, q_index, k_index)` and returning the
        boolean mask value (True means the query may attend to the key). The
        batch and head arguments are ignored.
    """
    assert 0 <= num_special_tokens <= block_size, 'num_special_tokens must be between 0 and block_size'

    # first within-block offset that holds a special token - invariant across mask evaluations
    special_token_index_start = block_size - num_special_tokens

    def create_mask(_, __, qi, ki):
        q_block_index = qi // block_size
        k_block_index = ki // block_size

        causal_mask = q_block_index >= k_block_index

        # special-ness depends only on the offset within the block, so the
        # restrictions below apply across blocks as well as within one

        q_is_special = (qi % block_size) >= special_token_index_start
        k_is_special = (ki % block_size) >= special_token_index_start

        if prevent_modality_to_special:
            # block (modality query -> special key)
            causal_mask = causal_mask & ~(~q_is_special & k_is_special)

        if prevent_special_to_modality:
            # block (special query -> modality key)
            causal_mask = causal_mask & ~(q_is_special & ~k_is_special)

        return causal_mask

    return create_mask

def flex_block_causal_mask(
    seq_len,
    block_size,
    num_special_tokens = 0,
    prevent_modality_to_special = False, # encoder of tokenizer as well as (perhaps crucially) the dynamics model
    prevent_special_to_modality = False  # decoder of tokenizer
):
    """Build the block mask for block causal attention via `create_block_mask`.

    See `block_causal_mask_mod` for the exact masking rule. The mask is built
    for a square `seq_len` x `seq_len` attention pattern with `B = None` and
    `H = None`, and with `_compile = True`.

    Args:
        seq_len: total sequence length, used for both query and key/value length.
        block_size: number of positions per block.
        num_special_tokens: number of trailing positions per block that are special.
        prevent_modality_to_special: modality queries may not attend to special keys.
        prevent_special_to_modality: special queries may not attend to modality keys.

    Returns:
        Whatever `create_block_mask` returns for the mask predicate.
    """
    create_mask = block_causal_mask_mod(
        block_size,
        num_special_tokens = num_special_tokens,
        prevent_modality_to_special = prevent_modality_to_special,
        prevent_special_to_modality = prevent_special_to_modality
    )

    block_mask = create_block_mask(create_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)
    return block_mask