complete all the types of attention masking patterns as proposed in the paper

This commit is contained in:
lucidrains 2025-10-04 12:45:54 -07:00
parent 5c6be4d979
commit 971637673b
2 changed files with 48 additions and 57 deletions

View File

@ -376,47 +376,6 @@ class MultiHeadRMSNorm(Module):
scale = (self.gamma + 1.) * self.scale
return einx.multiply('... h n d, h d', normed, scale)
# masking related
# block causal mask (space fully attends within each block, while time is causal)
def flex_block_mask(
seq_len,
block_size,
num_special_tokens = 0,
is_causal = True,
prevent_modality_to_special = False, # encoder of tokenizer as well as (perhaps crucially) the dynamics model
prevent_special_to_modality = False # decoder of tokenizer
):
assert num_special_tokens <= block_size
# assume special tokens (either latent or agent tokens) are placed at the right hand side
# so [modality] [latents | agent]
def create_mask(b, __, qi, ki):
q_block_index = qi // block_size
k_block_index = ki // block_size
special_token_index_start = block_size - num_special_tokens
q_is_special = (qi % block_size) >= special_token_index_start
k_is_special = (ki % block_size) >= special_token_index_start
mask = b >= -1 # make shift True tensor
if is_causal:
mask &= q_block_index >= k_block_index
if prevent_modality_to_special:
mask &= ~(q_is_special & ~k_is_special)
if prevent_special_to_modality:
mask &= ~(~q_is_special & k_is_special)
return mask
block_mask = create_block_mask(create_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)
return block_mask
# naive attend
def naive_attend(
@ -493,20 +452,31 @@ def block_mask_causal(block_size):
bq = q // block_size
bk = k // block_size
return bq >= bk
return inner
def agent_token_mask(q, k, seq_len, num_tokens):
def special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself = False):
bq = q % seq_len
bk = k % seq_len
is_special_start_index = seq_len - num_tokens
q_is_special = q >= is_special_start_index
k_is_special = k >= is_special_start_index
return ~(~q_is_special & k_is_special) # modality cannot attend to agent tokens
if special_attend_only_itself:
out = ~(q_is_special & ~k_is_special) # modality attends to everything, but latent can only attend to itself (proposed attention pattern for encoder of video tokenizer)
else:
out = ~(~q_is_special & k_is_special) # modality cannot attend to agent tokens
return out
def block_mask_special_tokens_right(
seq_len,
num_tokens
):
def inner(b, h, q, k):
return agent_token_mask(q, k, seq_len, num_tokens)
return special_token_mask(q, k, seq_len, num_tokens)
return inner
def compose_mask(mask1, mask2):
@ -539,17 +509,21 @@ def get_attend_fn(
causal = False,
causal_block_size = 1,
softclamp_value = 50.,
num_agent_tokens = 0,
num_special_tokens = 0, # special tokens are latents / agents
block_size_per_special = None, # defaults to k_seq_len
special_attend_only_itself = False, # by default, modality only attends to itself while special sees everything, but if turned True, will be the inverse - special can only attend to itself but modality can attend everything
device = None
):
block_size_per_special = default(block_size_per_special, k_seq_len)
if use_flex:
# flex pathway
block_mask_fn = block_mask_causal(causal_block_size) if causal else block_mask_noop
if num_agent_tokens > 0:
agent_block_mask = block_mask_special_tokens_right(k_seq_len, num_agent_tokens)
block_mask_fn = compose_mask(block_mask_fn, agent_block_mask)
if num_special_tokens > 0:
special_block_mask = block_mask_special_tokens_right(block_size_per_special, num_special_tokens, special_attend_only_itself)
block_mask_fn = compose_mask(block_mask_fn, special_block_mask)
block_mask = create_block_mask(block_mask_fn, B = None, H = None, Q_LEN = seq_len, KV_LEN = k_seq_len)
@ -559,11 +533,11 @@ def get_attend_fn(
# naive pathway
mask = None
if num_agent_tokens > 0:
if num_special_tokens > 0:
q_seq = torch.arange(seq_len, device = device)[:, None]
k_seq = torch.arange(k_seq_len, device = device)[None, :]
mask = agent_token_mask(q_seq, k_seq, k_seq_len, num_agent_tokens)
mask = special_token_mask(q_seq, k_seq, block_size_per_special, num_special_tokens, special_attend_only_itself)
attend_fn = partial(naive_attend, causal = causal, causal_block_size = causal_block_size, mask = mask, softclamp_value = softclamp_value)
@ -890,17 +864,24 @@ class VideoTokenizer(Module):
attend_kwargs = dict(
causal = True,
causal_block_size = space_seq_len,
softclamp_value = self.attn_softclamp_value
softclamp_value = self.attn_softclamp_value,
block_size_per_special = space_seq_len,
num_special_tokens = 1
)
use_flex = tokens.is_cuda and exists(flex_attention)
attend_fn = get_attend_fn(use_flex, seq_len, seq_len)
# encoder attend
# modality can only attend to itself while latents can attend to everything
# similar to agent token in dynamics model
encoder_attend_fn = get_attend_fn(use_flex, seq_len, seq_len, special_attend_only_itself = True)
# encoder
for attn, ff in self.encoder_layers:
tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = attend_fn) + tokens
tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = encoder_attend_fn) + tokens
tokens = ff(tokens) + tokens
tokens = self.encoder_norm(tokens)
@ -930,10 +911,14 @@ class VideoTokenizer(Module):
tokens, _ = pack((decoder_pos_emb, latent_tokens), 'b * d')
# decoder attend
decoder_attend_fn = get_attend_fn(use_flex, seq_len, seq_len)
# decoder attention
for attn, ff in self.decoder_layers:
tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = attend_fn) + tokens
tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = decoder_attend_fn) + tokens
tokens = ff(tokens) + tokens
@ -1172,7 +1157,7 @@ class DynamicsModel(Module):
attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, device = device)
space_attend = get_attend_fn(causal = False, seq_len = self.space_seq_len, k_seq_len = self.space_seq_len, num_agent_tokens = 1, **attend_kwargs) # space has an agent token on the right-hand side for reinforcement learning - cannot be attended to by modality
space_attend = get_attend_fn(causal = False, seq_len = self.space_seq_len, k_seq_len = self.space_seq_len, num_special_tokens = 1, **attend_kwargs) # space has an agent token on the right-hand side for reinforcement learning - cannot be attended to by modality
time_attend = get_attend_fn(causal = True, seq_len = time, k_seq_len = time, **attend_kwargs)

View File

@ -64,11 +64,15 @@ def test_symexp_two_hot():
@param('softclamp_value', (50., None))
@param('num_agent_tokens', (0, 1))
@param('causal_block_size', (1, 8))
@param('block_size_per_special', (1, 8))
@param('special_attend_only_itself', (False, True))
def test_attend_factory(
causal,
softclamp_value,
num_agent_tokens,
causal_block_size
causal_block_size,
block_size_per_special,
special_attend_only_itself
):
from dreamer4.dreamer4 import get_attend_fn
@ -84,7 +88,9 @@ def test_attend_factory(
causal_block_size = causal_block_size,
softclamp_value = softclamp_value,
device = q.device,
num_agent_tokens = num_agent_tokens
num_agent_tokens = num_agent_tokens,
block_size_per_special = block_size_per_special,
special_attend_only_itself = special_attend_only_itself
)
attend = get_attend_fn(True, **attend_kwargs)