cover the attention masking for the tokenizer encoder and decoder, as well as the dynamics model (latent and agent tokens are "special" and placed on the right of each block)

lucidrains 2025-10-01 12:11:06 -07:00
parent c18c624be6
commit 8e7a35b89c

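The diff below assumes each time block lays out its tokens as [modality] [latents | agent], with the special tokens on the right-hand side. As a quick illustration of that index math (the block_size and num_special_tokens values here are made up for the example):

block_size = 6           # tokens per time step, illustrative
num_special_tokens = 2   # latent / agent tokens at the right of each block

for pos in range(2 * block_size):  # two time steps
    block_index, offset = divmod(pos, block_size)
    is_special = offset >= (block_size - num_special_tokens)
    print(pos, block_index, 'special' if is_special else 'modality')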

@@ -156,12 +156,36 @@ class MultiHeadRMSNorm(Module):
 
 # masking related
 
 # block causal mask (space fully attends within each block, while time is causal)
 
-def flex_block_causal_mask(seq_len, block_size):
+def flex_block_causal_mask(
+    seq_len,
+    block_size,
+    num_special_tokens = 0,
+    prevent_modality_to_special = False, # encoder of tokenizer as well as (perhaps crucially) the dynamics model
+    prevent_special_to_modality = False  # decoder of tokenizer
+):
+    assert num_special_tokens <= block_size
+
+    # assume special tokens (either latent or agent tokens) are placed at the right hand side of each block
+    # so [modality] [latents | agent]
 
     def create_mask(_, __, qi, ki):
         q_block_index = qi // block_size
         k_block_index = ki // block_size
-        return q_block_index >= k_block_index
+
+        special_token_index_start = block_size - num_special_tokens
+
+        q_is_special = (qi % block_size) >= special_token_index_start
+        k_is_special = (ki % block_size) >= special_token_index_start
+
+        # time is block causal - space fully attends within each block
+        causal_mask = q_block_index >= k_block_index
+
+        # "x to y" reads as query tokens x attending to key tokens y
+
+        if prevent_modality_to_special:
+            # modality queries cannot attend to special keys (tokenizer encoder, dynamics model)
+            causal_mask = causal_mask & ~(~q_is_special & k_is_special)
+
+        if prevent_special_to_modality:
+            # special queries cannot attend to modality keys (tokenizer decoder)
+            causal_mask = causal_mask & ~(q_is_special & ~k_is_special)
+
+        return causal_mask
 
     block_mask = create_block_mask(create_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)
 
     return block_mask
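A minimal usage sketch, assuming PyTorch 2.5+ flex attention (the same torch.nn.attention.flex_attention module that supplies create_block_mask above). The shapes and token counts are illustrative, the per-module flag settings follow the comments in the signature, and a CUDA device is assumed since create_block_mask defaults to one:

import torch
from torch.nn.attention.flex_attention import flex_attention

block_size = 128    # tokens per time step, illustrative
seq_len = 4 * block_size
num_special = 16    # latent / agent tokens per block, illustrative

# tokenizer encoder - modality tokens do not attend to the latents being built
encoder_mask = flex_block_causal_mask(seq_len, block_size, num_special, prevent_modality_to_special = True)

# tokenizer decoder - latents do not attend to the masked-out modality positions
decoder_mask = flex_block_causal_mask(seq_len, block_size, num_special, prevent_special_to_modality = True)

# dynamics model - same flag as the encoder, per the signature comment,
# so agent tokens can read the world state but never influence it
dynamics_mask = flex_block_causal_mask(seq_len, block_size, num_special, prevent_modality_to_special = True)

q = k = v = torch.randn(1, 8, seq_len, 64, device = 'cuda')  # (batch, heads, seq, dim head)
out = flex_attention(q, k, v, block_mask = encoder_mask)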