complete all the types of attention masking patterns as proposed in the paper

This commit is contained in:
lucidrains 2025-10-04 12:45:54 -07:00
parent 5c6be4d979
commit 971637673b
2 changed files with 48 additions and 57 deletions

View File

@ -376,47 +376,6 @@ class MultiHeadRMSNorm(Module):
scale = (self.gamma + 1.) * self.scale scale = (self.gamma + 1.) * self.scale
return einx.multiply('... h n d, h d', normed, scale) return einx.multiply('... h n d, h d', normed, scale)
# masking related
# block causal mask (space fully attends within each block, while time is causal)
def flex_block_mask(
seq_len,
block_size,
num_special_tokens = 0,
is_causal = True,
prevent_modality_to_special = False, # encoder of tokenizer as well as (perhaps crucially) the dynamics model
prevent_special_to_modality = False # decoder of tokenizer
):
assert num_special_tokens <= block_size
# assume special tokens (either latent or agent tokens) are placed at the right hand side
# so [modality] [latents | agent]
def create_mask(b, __, qi, ki):
q_block_index = qi // block_size
k_block_index = ki // block_size
special_token_index_start = block_size - num_special_tokens
q_is_special = (qi % block_size) >= special_token_index_start
k_is_special = (ki % block_size) >= special_token_index_start
mask = b >= -1 # make shift True tensor
if is_causal:
mask &= q_block_index >= k_block_index
if prevent_modality_to_special:
mask &= ~(q_is_special & ~k_is_special)
if prevent_special_to_modality:
mask &= ~(~q_is_special & k_is_special)
return mask
block_mask = create_block_mask(create_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)
return block_mask
# naive attend # naive attend
def naive_attend( def naive_attend(
@ -493,20 +452,31 @@ def block_mask_causal(block_size):
bq = q // block_size bq = q // block_size
bk = k // block_size bk = k // block_size
return bq >= bk return bq >= bk
return inner return inner
def agent_token_mask(q, k, seq_len, num_tokens): def special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself = False):
bq = q % seq_len
bk = k % seq_len
is_special_start_index = seq_len - num_tokens is_special_start_index = seq_len - num_tokens
q_is_special = q >= is_special_start_index q_is_special = q >= is_special_start_index
k_is_special = k >= is_special_start_index k_is_special = k >= is_special_start_index
return ~(~q_is_special & k_is_special) # modality cannot attend to agent tokens
if special_attend_only_itself:
out = ~(q_is_special & ~k_is_special) # modality attends to everything, but latent can only attend to itself (proposed attention pattern for encoder of video tokenizer)
else:
out = ~(~q_is_special & k_is_special) # modality cannot attend to agent tokens
return out
def block_mask_special_tokens_right( def block_mask_special_tokens_right(
seq_len, seq_len,
num_tokens num_tokens
): ):
def inner(b, h, q, k): def inner(b, h, q, k):
return agent_token_mask(q, k, seq_len, num_tokens) return special_token_mask(q, k, seq_len, num_tokens)
return inner return inner
def compose_mask(mask1, mask2): def compose_mask(mask1, mask2):
@ -539,17 +509,21 @@ def get_attend_fn(
causal = False, causal = False,
causal_block_size = 1, causal_block_size = 1,
softclamp_value = 50., softclamp_value = 50.,
num_agent_tokens = 0, num_special_tokens = 0, # special tokens are latents / agents
block_size_per_special = None, # defaults to k_seq_len
special_attend_only_itself = False, # by default, modality only attends to itself while special sees everything, but if turned True, will be the inverse - special can only attend to itself but modality can attend everything
device = None device = None
): ):
block_size_per_special = default(block_size_per_special, k_seq_len)
if use_flex: if use_flex:
# flex pathway # flex pathway
block_mask_fn = block_mask_causal(causal_block_size) if causal else block_mask_noop block_mask_fn = block_mask_causal(causal_block_size) if causal else block_mask_noop
if num_agent_tokens > 0: if num_special_tokens > 0:
agent_block_mask = block_mask_special_tokens_right(k_seq_len, num_agent_tokens) special_block_mask = block_mask_special_tokens_right(block_size_per_special, num_special_tokens, special_attend_only_itself)
block_mask_fn = compose_mask(block_mask_fn, agent_block_mask) block_mask_fn = compose_mask(block_mask_fn, special_block_mask)
block_mask = create_block_mask(block_mask_fn, B = None, H = None, Q_LEN = seq_len, KV_LEN = k_seq_len) block_mask = create_block_mask(block_mask_fn, B = None, H = None, Q_LEN = seq_len, KV_LEN = k_seq_len)
@ -559,11 +533,11 @@ def get_attend_fn(
# naive pathway # naive pathway
mask = None mask = None
if num_agent_tokens > 0: if num_special_tokens > 0:
q_seq = torch.arange(seq_len, device = device)[:, None] q_seq = torch.arange(seq_len, device = device)[:, None]
k_seq = torch.arange(k_seq_len, device = device)[None, :] k_seq = torch.arange(k_seq_len, device = device)[None, :]
mask = agent_token_mask(q_seq, k_seq, k_seq_len, num_agent_tokens) mask = special_token_mask(q_seq, k_seq, block_size_per_special, num_special_tokens, special_attend_only_itself)
attend_fn = partial(naive_attend, causal = causal, causal_block_size = causal_block_size, mask = mask, softclamp_value = softclamp_value) attend_fn = partial(naive_attend, causal = causal, causal_block_size = causal_block_size, mask = mask, softclamp_value = softclamp_value)
@ -890,17 +864,24 @@ class VideoTokenizer(Module):
attend_kwargs = dict( attend_kwargs = dict(
causal = True, causal = True,
causal_block_size = space_seq_len, causal_block_size = space_seq_len,
softclamp_value = self.attn_softclamp_value softclamp_value = self.attn_softclamp_value,
block_size_per_special = space_seq_len,
num_special_tokens = 1
) )
use_flex = tokens.is_cuda and exists(flex_attention) use_flex = tokens.is_cuda and exists(flex_attention)
attend_fn = get_attend_fn(use_flex, seq_len, seq_len) # encoder attend
# modality can only attend to itself while latents can attend to everything
# similar to agent token in dynamics model
encoder_attend_fn = get_attend_fn(use_flex, seq_len, seq_len, special_attend_only_itself = True)
# encoder # encoder
for attn, ff in self.encoder_layers: for attn, ff in self.encoder_layers:
tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = attend_fn) + tokens tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = encoder_attend_fn) + tokens
tokens = ff(tokens) + tokens tokens = ff(tokens) + tokens
tokens = self.encoder_norm(tokens) tokens = self.encoder_norm(tokens)
@ -930,10 +911,14 @@ class VideoTokenizer(Module):
tokens, _ = pack((decoder_pos_emb, latent_tokens), 'b * d') tokens, _ = pack((decoder_pos_emb, latent_tokens), 'b * d')
# decoder attend
decoder_attend_fn = get_attend_fn(use_flex, seq_len, seq_len)
# decoder attention # decoder attention
for attn, ff in self.decoder_layers: for attn, ff in self.decoder_layers:
tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = attend_fn) + tokens tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = decoder_attend_fn) + tokens
tokens = ff(tokens) + tokens tokens = ff(tokens) + tokens
@ -1172,7 +1157,7 @@ class DynamicsModel(Module):
attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, device = device) attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, device = device)
space_attend = get_attend_fn(causal = False, seq_len = self.space_seq_len, k_seq_len = self.space_seq_len, num_agent_tokens = 1, **attend_kwargs) # space has an agent token on the right-hand side for reinforcement learning - cannot be attended to by modality space_attend = get_attend_fn(causal = False, seq_len = self.space_seq_len, k_seq_len = self.space_seq_len, num_special_tokens = 1, **attend_kwargs) # space has an agent token on the right-hand side for reinforcement learning - cannot be attended to by modality
time_attend = get_attend_fn(causal = True, seq_len = time, k_seq_len = time, **attend_kwargs) time_attend = get_attend_fn(causal = True, seq_len = time, k_seq_len = time, **attend_kwargs)

View File

@ -64,11 +64,15 @@ def test_symexp_two_hot():
@param('softclamp_value', (50., None)) @param('softclamp_value', (50., None))
@param('num_agent_tokens', (0, 1)) @param('num_agent_tokens', (0, 1))
@param('causal_block_size', (1, 8)) @param('causal_block_size', (1, 8))
@param('block_size_per_special', (1, 8))
@param('special_attend_only_itself', (False, True))
def test_attend_factory( def test_attend_factory(
causal, causal,
softclamp_value, softclamp_value,
num_agent_tokens, num_agent_tokens,
causal_block_size causal_block_size,
block_size_per_special,
special_attend_only_itself
): ):
from dreamer4.dreamer4 import get_attend_fn from dreamer4.dreamer4 import get_attend_fn
@ -84,7 +88,9 @@ def test_attend_factory(
causal_block_size = causal_block_size, causal_block_size = causal_block_size,
softclamp_value = softclamp_value, softclamp_value = softclamp_value,
device = q.device, device = q.device,
num_agent_tokens = num_agent_tokens num_agent_tokens = num_agent_tokens,
block_size_per_special = block_size_per_special,
special_attend_only_itself = special_attend_only_itself
) )
attend = get_attend_fn(True, **attend_kwargs) attend = get_attend_fn(True, **attend_kwargs)