complete all the types of attention masking patterns as proposed in the paper
This commit is contained in:
parent
5c6be4d979
commit
971637673b
@ -376,47 +376,6 @@ class MultiHeadRMSNorm(Module):
|
||||
scale = (self.gamma + 1.) * self.scale
|
||||
return einx.multiply('... h n d, h d', normed, scale)
|
||||
|
||||
# masking related
|
||||
# block causal mask (space fully attends within each block, while time is causal)
|
||||
|
||||
def flex_block_mask(
|
||||
seq_len,
|
||||
block_size,
|
||||
num_special_tokens = 0,
|
||||
is_causal = True,
|
||||
prevent_modality_to_special = False, # encoder of tokenizer as well as (perhaps crucially) the dynamics model
|
||||
prevent_special_to_modality = False # decoder of tokenizer
|
||||
):
|
||||
assert num_special_tokens <= block_size
|
||||
|
||||
# assume special tokens (either latent or agent tokens) are placed at the right hand side
|
||||
# so [modality] [latents | agent]
|
||||
|
||||
def create_mask(b, __, qi, ki):
|
||||
q_block_index = qi // block_size
|
||||
k_block_index = ki // block_size
|
||||
|
||||
special_token_index_start = block_size - num_special_tokens
|
||||
|
||||
q_is_special = (qi % block_size) >= special_token_index_start
|
||||
k_is_special = (ki % block_size) >= special_token_index_start
|
||||
|
||||
mask = b >= -1 # make shift True tensor
|
||||
|
||||
if is_causal:
|
||||
mask &= q_block_index >= k_block_index
|
||||
|
||||
if prevent_modality_to_special:
|
||||
mask &= ~(q_is_special & ~k_is_special)
|
||||
|
||||
if prevent_special_to_modality:
|
||||
mask &= ~(~q_is_special & k_is_special)
|
||||
|
||||
return mask
|
||||
|
||||
block_mask = create_block_mask(create_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)
|
||||
return block_mask
|
||||
|
||||
# naive attend
|
||||
|
||||
def naive_attend(
|
||||
@ -493,20 +452,31 @@ def block_mask_causal(block_size):
|
||||
bq = q // block_size
|
||||
bk = k // block_size
|
||||
return bq >= bk
|
||||
|
||||
return inner
|
||||
|
||||
def agent_token_mask(q, k, seq_len, num_tokens):
|
||||
def special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself = False):
|
||||
bq = q % seq_len
|
||||
bk = k % seq_len
|
||||
|
||||
is_special_start_index = seq_len - num_tokens
|
||||
|
||||
q_is_special = q >= is_special_start_index
|
||||
k_is_special = k >= is_special_start_index
|
||||
return ~(~q_is_special & k_is_special) # modality cannot attend to agent tokens
|
||||
|
||||
if special_attend_only_itself:
|
||||
out = ~(q_is_special & ~k_is_special) # modality attends to everything, but latent can only attend to itself (proposed attention pattern for encoder of video tokenizer)
|
||||
else:
|
||||
out = ~(~q_is_special & k_is_special) # modality cannot attend to agent tokens
|
||||
|
||||
return out
|
||||
|
||||
def block_mask_special_tokens_right(
|
||||
seq_len,
|
||||
num_tokens
|
||||
):
|
||||
def inner(b, h, q, k):
|
||||
return agent_token_mask(q, k, seq_len, num_tokens)
|
||||
return special_token_mask(q, k, seq_len, num_tokens)
|
||||
return inner
|
||||
|
||||
def compose_mask(mask1, mask2):
|
||||
@ -539,17 +509,21 @@ def get_attend_fn(
|
||||
causal = False,
|
||||
causal_block_size = 1,
|
||||
softclamp_value = 50.,
|
||||
num_agent_tokens = 0,
|
||||
num_special_tokens = 0, # special tokens are latents / agents
|
||||
block_size_per_special = None, # defaults to k_seq_len
|
||||
special_attend_only_itself = False, # by default, modality only attends to itself while special sees everything, but if turned True, will be the inverse - special can only attend to itself but modality can attend everything
|
||||
device = None
|
||||
):
|
||||
block_size_per_special = default(block_size_per_special, k_seq_len)
|
||||
|
||||
if use_flex:
|
||||
# flex pathway
|
||||
|
||||
block_mask_fn = block_mask_causal(causal_block_size) if causal else block_mask_noop
|
||||
|
||||
if num_agent_tokens > 0:
|
||||
agent_block_mask = block_mask_special_tokens_right(k_seq_len, num_agent_tokens)
|
||||
block_mask_fn = compose_mask(block_mask_fn, agent_block_mask)
|
||||
if num_special_tokens > 0:
|
||||
special_block_mask = block_mask_special_tokens_right(block_size_per_special, num_special_tokens, special_attend_only_itself)
|
||||
block_mask_fn = compose_mask(block_mask_fn, special_block_mask)
|
||||
|
||||
block_mask = create_block_mask(block_mask_fn, B = None, H = None, Q_LEN = seq_len, KV_LEN = k_seq_len)
|
||||
|
||||
@ -559,11 +533,11 @@ def get_attend_fn(
|
||||
# naive pathway
|
||||
|
||||
mask = None
|
||||
if num_agent_tokens > 0:
|
||||
if num_special_tokens > 0:
|
||||
q_seq = torch.arange(seq_len, device = device)[:, None]
|
||||
k_seq = torch.arange(k_seq_len, device = device)[None, :]
|
||||
|
||||
mask = agent_token_mask(q_seq, k_seq, k_seq_len, num_agent_tokens)
|
||||
mask = special_token_mask(q_seq, k_seq, block_size_per_special, num_special_tokens, special_attend_only_itself)
|
||||
|
||||
attend_fn = partial(naive_attend, causal = causal, causal_block_size = causal_block_size, mask = mask, softclamp_value = softclamp_value)
|
||||
|
||||
@ -890,17 +864,24 @@ class VideoTokenizer(Module):
|
||||
attend_kwargs = dict(
|
||||
causal = True,
|
||||
causal_block_size = space_seq_len,
|
||||
softclamp_value = self.attn_softclamp_value
|
||||
softclamp_value = self.attn_softclamp_value,
|
||||
block_size_per_special = space_seq_len,
|
||||
num_special_tokens = 1
|
||||
)
|
||||
|
||||
use_flex = tokens.is_cuda and exists(flex_attention)
|
||||
|
||||
attend_fn = get_attend_fn(use_flex, seq_len, seq_len)
|
||||
# encoder attend
|
||||
|
||||
# modality can only attend to itself while latents can attend to everything
|
||||
# similar to agent token in dynamics model
|
||||
|
||||
encoder_attend_fn = get_attend_fn(use_flex, seq_len, seq_len, special_attend_only_itself = True)
|
||||
|
||||
# encoder
|
||||
|
||||
for attn, ff in self.encoder_layers:
|
||||
tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = attend_fn) + tokens
|
||||
tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = encoder_attend_fn) + tokens
|
||||
tokens = ff(tokens) + tokens
|
||||
|
||||
tokens = self.encoder_norm(tokens)
|
||||
@ -930,10 +911,14 @@ class VideoTokenizer(Module):
|
||||
|
||||
tokens, _ = pack((decoder_pos_emb, latent_tokens), 'b * d')
|
||||
|
||||
# decoder attend
|
||||
|
||||
decoder_attend_fn = get_attend_fn(use_flex, seq_len, seq_len)
|
||||
|
||||
# decoder attention
|
||||
|
||||
for attn, ff in self.decoder_layers:
|
||||
tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = attend_fn) + tokens
|
||||
tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = decoder_attend_fn) + tokens
|
||||
|
||||
tokens = ff(tokens) + tokens
|
||||
|
||||
@ -1172,7 +1157,7 @@ class DynamicsModel(Module):
|
||||
|
||||
attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, device = device)
|
||||
|
||||
space_attend = get_attend_fn(causal = False, seq_len = self.space_seq_len, k_seq_len = self.space_seq_len, num_agent_tokens = 1, **attend_kwargs) # space has an agent token on the right-hand side for reinforcement learning - cannot be attended to by modality
|
||||
space_attend = get_attend_fn(causal = False, seq_len = self.space_seq_len, k_seq_len = self.space_seq_len, num_special_tokens = 1, **attend_kwargs) # space has an agent token on the right-hand side for reinforcement learning - cannot be attended to by modality
|
||||
|
||||
time_attend = get_attend_fn(causal = True, seq_len = time, k_seq_len = time, **attend_kwargs)
|
||||
|
||||
|
||||
@ -64,11 +64,15 @@ def test_symexp_two_hot():
|
||||
@param('softclamp_value', (50., None))
|
||||
@param('num_agent_tokens', (0, 1))
|
||||
@param('causal_block_size', (1, 8))
|
||||
@param('block_size_per_special', (1, 8))
|
||||
@param('special_attend_only_itself', (False, True))
|
||||
def test_attend_factory(
|
||||
causal,
|
||||
softclamp_value,
|
||||
num_agent_tokens,
|
||||
causal_block_size
|
||||
causal_block_size,
|
||||
block_size_per_special,
|
||||
special_attend_only_itself
|
||||
):
|
||||
|
||||
from dreamer4.dreamer4 import get_attend_fn
|
||||
@ -84,7 +88,9 @@ def test_attend_factory(
|
||||
causal_block_size = causal_block_size,
|
||||
softclamp_value = softclamp_value,
|
||||
device = q.device,
|
||||
num_agent_tokens = num_agent_tokens
|
||||
num_agent_tokens = num_agent_tokens,
|
||||
block_size_per_special = block_size_per_special,
|
||||
special_attend_only_itself = special_attend_only_itself
|
||||
)
|
||||
|
||||
attend = get_attend_fn(True, **attend_kwargs)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user