complete all the types of attention masking patterns as proposed in the paper
parent 5c6be4d979
commit 971637673b
@@ -376,47 +376,6 @@ class MultiHeadRMSNorm(Module):
         scale = (self.gamma + 1.) * self.scale
         return einx.multiply('... h n d, h d', normed, scale)
 
-# masking related
-
-# block causal mask (space fully attends within each block, while time is causal)
-
-def flex_block_mask(
-    seq_len,
-    block_size,
-    num_special_tokens = 0,
-    is_causal = True,
-    prevent_modality_to_special = False, # encoder of tokenizer as well as (perhaps crucially) the dynamics model
-    prevent_special_to_modality = False  # decoder of tokenizer
-):
-    assert num_special_tokens <= block_size
-
-    # assume special tokens (either latent or agent tokens) are placed at the right hand side
-    # so [modality] [latents | agent]
-
-    def create_mask(b, __, qi, ki):
-        q_block_index = qi // block_size
-        k_block_index = ki // block_size
-
-        special_token_index_start = block_size - num_special_tokens
-
-        q_is_special = (qi % block_size) >= special_token_index_start
-        k_is_special = (ki % block_size) >= special_token_index_start
-
-        mask = b >= -1 # make shift True tensor
-
-        if is_causal:
-            mask &= q_block_index >= k_block_index
-
-        if prevent_modality_to_special:
-            mask &= ~(q_is_special & ~k_is_special)
-
-        if prevent_special_to_modality:
-            mask &= ~(~q_is_special & k_is_special)
-
-        return mask
-
-    block_mask = create_block_mask(create_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)
-    return block_mask
-
 # naive attend
 
 def naive_attend(
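The removed flex_block_mask folded two concerns into one flex-attention mask_mod: block-causal attention across time (space attends freely within its own block) and special-token isolation. A minimal standalone sketch of just the block-causal half, on a toy sequence of three time blocks of two spatial tokens each (names and sizes here are illustrative, not repo code):

import torch

block_size = 2                 # spatial tokens per time step
seq_len    = 6                 # three time blocks

q = torch.arange(seq_len)[:, None]
k = torch.arange(seq_len)[None, :]

# space attends freely inside its own block, time is causal across blocks
mask = (q // block_size) >= (k // block_size)

print(mask.int())
# tensor([[1, 1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1, 1],
#         [1, 1, 1, 1, 1, 1]])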
@@ -493,20 +452,31 @@ def block_mask_causal(block_size):
         bq = q // block_size
         bk = k // block_size
         return bq >= bk
 
     return inner
 
-def agent_token_mask(q, k, seq_len, num_tokens):
+def special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself = False):
+    bq = q % seq_len
+    bk = k % seq_len
+
     is_special_start_index = seq_len - num_tokens
 
     q_is_special = q >= is_special_start_index
     k_is_special = k >= is_special_start_index
-    return ~(~q_is_special & k_is_special) # modality cannot attend to agent tokens
+
+    if special_attend_only_itself:
+        out = ~(q_is_special & ~k_is_special) # modality attends to everything, but latent can only attend to itself (proposed attention pattern for encoder of video tokenizer)
+    else:
+        out = ~(~q_is_special & k_is_special) # modality cannot attend to agent tokens
+
+    return out
 
 def block_mask_special_tokens_right(
     seq_len,
     num_tokens
 ):
     def inner(b, h, q, k):
-        return agent_token_mask(q, k, seq_len, num_tokens)
+        return special_token_mask(q, k, seq_len, num_tokens)
     return inner
 
 def compose_mask(mask1, mask2):
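To make the two modes of special_token_mask concrete, here is a standalone toy example (three modality tokens followed by one special latent/agent token on the right; the tensors simply mirror the boolean expressions above, this is not repo code):

import torch

seq_len, num_tokens = 4, 1
q = torch.arange(seq_len)[:, None]
k = torch.arange(seq_len)[None, :]

start = seq_len - num_tokens
q_is_special = q >= start
k_is_special = k >= start

# special_attend_only_itself = False: modality cannot see the special token, special sees everything
default_mask = ~(~q_is_special & k_is_special)

# special_attend_only_itself = True: modality sees everything, the special token only sees itself
itself_mask = ~(q_is_special & ~k_is_special)

print(default_mask.int())
# tensor([[1, 1, 1, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]])
print(itself_mask.int())
# tensor([[1, 1, 1, 1],
#         [1, 1, 1, 1],
#         [1, 1, 1, 1],
#         [0, 0, 0, 1]])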
@@ -539,17 +509,21 @@ def get_attend_fn(
     causal = False,
     causal_block_size = 1,
     softclamp_value = 50.,
-    num_agent_tokens = 0,
+    num_special_tokens = 0, # special tokens are latents / agents
+    block_size_per_special = None, # defaults to k_seq_len
+    special_attend_only_itself = False, # by default, modality only attends to itself while special sees everything, but if turned True, will be the inverse - special can only attend to itself but modality can attend everything
     device = None
 ):
+    block_size_per_special = default(block_size_per_special, k_seq_len)
+
     if use_flex:
         # flex pathway
 
         block_mask_fn = block_mask_causal(causal_block_size) if causal else block_mask_noop
 
-        if num_agent_tokens > 0:
-            agent_block_mask = block_mask_special_tokens_right(k_seq_len, num_agent_tokens)
-            block_mask_fn = compose_mask(block_mask_fn, agent_block_mask)
+        if num_special_tokens > 0:
+            special_block_mask = block_mask_special_tokens_right(block_size_per_special, num_special_tokens, special_attend_only_itself)
+            block_mask_fn = compose_mask(block_mask_fn, special_block_mask)
 
         block_mask = create_block_mask(block_mask_fn, B = None, H = None, Q_LEN = seq_len, KV_LEN = k_seq_len)
 
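compose_mask is unchanged by this commit and its body sits outside the hunk; the sketch below assumes it simply ANDs two mask_mod callables of the flex-attention (b, h, q, k) signature, which is how the causal and special-token masks are combined in the flex pathway (the helper lambdas are stand-ins, not the repo's functions):

import torch

def compose_mask(mask1, mask2):
    # assumed shape: logical AND of two mask_mod callables
    def inner(b, h, q, k):
        return mask1(b, h, q, k) & mask2(b, h, q, k)
    return inner

causal_fn  = lambda b, h, q, k: (q // 2) >= (k // 2)    # block causal, block size 2
special_fn = lambda b, h, q, k: ~((q < 5) & (k >= 5))   # modality (pos < 5) cannot see the special token at pos 5

composed = compose_mask(causal_fn, special_fn)

q = torch.arange(6)[:, None]
k = torch.arange(6)[None, :]
print(composed(None, None, q, k).int())                 # dense sanity check of the composed mask_mod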
@@ -559,11 +533,11 @@ def get_attend_fn(
         # naive pathway
 
         mask = None
-        if num_agent_tokens > 0:
+        if num_special_tokens > 0:
             q_seq = torch.arange(seq_len, device = device)[:, None]
             k_seq = torch.arange(k_seq_len, device = device)[None, :]
 
-            mask = agent_token_mask(q_seq, k_seq, k_seq_len, num_agent_tokens)
+            mask = special_token_mask(q_seq, k_seq, block_size_per_special, num_special_tokens, special_attend_only_itself)
 
         attend_fn = partial(naive_attend, causal = causal, causal_block_size = causal_block_size, mask = mask, softclamp_value = softclamp_value)
 
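naive_attend's body is not part of this diff, but the softclamp_value it receives corresponds to the usual tanh soft-capping of attention logits before the softmax; a hedged sketch of that step (an assumption, not copied from the repo):

import torch

def softclamp(logits, value = 50.):
    # smoothly cap logits into (-value, value) before softmax
    return torch.tanh(logits / value) * value

scores = torch.randn(2, 8, 16, 16) * 100.      # (batch, heads, q_len, k_len) attention logits
capped = softclamp(scores)
print(capped.abs().max() < 50.)                # tensor(True)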
@@ -890,17 +864,24 @@ class VideoTokenizer(Module):
         attend_kwargs = dict(
             causal = True,
             causal_block_size = space_seq_len,
-            softclamp_value = self.attn_softclamp_value
+            softclamp_value = self.attn_softclamp_value,
+            block_size_per_special = space_seq_len,
+            num_special_tokens = 1
         )
 
         use_flex = tokens.is_cuda and exists(flex_attention)
 
-        attend_fn = get_attend_fn(use_flex, seq_len, seq_len)
+        # encoder attend
+
+        # modality can only attend to itself while latents can attend to everything
+        # similar to agent token in dynamics model
+
+        encoder_attend_fn = get_attend_fn(use_flex, seq_len, seq_len, special_attend_only_itself = True)
 
         # encoder
 
         for attn, ff in self.encoder_layers:
-            tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = attend_fn) + tokens
+            tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = encoder_attend_fn) + tokens
             tokens = ff(tokens) + tokens
 
         tokens = self.encoder_norm(tokens)
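With block_size_per_special = space_seq_len and num_special_tokens = 1, the apparent intent is that a single latent occupies the last slot of every spatial block, so the special-token pattern repeats once per frame. A toy layout check under that assumption (sizes are illustrative, not repo values):

import torch

space_seq_len = 4                      # e.g. 3 patch tokens + 1 latent per frame
seq_len = space_seq_len * 2            # two frames

pos = torch.arange(seq_len)
is_latent = (pos % space_seq_len) == (space_seq_len - 1)
print(is_latent)
# tensor([False, False, False,  True, False, False, False,  True])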
@@ -930,10 +911,14 @@ class VideoTokenizer(Module):
 
         tokens, _ = pack((decoder_pos_emb, latent_tokens), 'b * d')
 
+        # decoder attend
+
+        decoder_attend_fn = get_attend_fn(use_flex, seq_len, seq_len)
+
         # decoder attention
 
         for attn, ff in self.decoder_layers:
-            tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = attend_fn) + tokens
+            tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = decoder_attend_fn) + tokens
 
             tokens = ff(tokens) + tokens
 
@@ -1172,7 +1157,7 @@ class DynamicsModel(Module):
 
         attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, device = device)
 
-        space_attend = get_attend_fn(causal = False, seq_len = self.space_seq_len, k_seq_len = self.space_seq_len, num_agent_tokens = 1, **attend_kwargs) # space has an agent token on the right-hand side for reinforcement learning - cannot be attended to by modality
+        space_attend = get_attend_fn(causal = False, seq_len = self.space_seq_len, k_seq_len = self.space_seq_len, num_special_tokens = 1, **attend_kwargs) # space has an agent token on the right-hand side for reinforcement learning - cannot be attended to by modality
 
         time_attend = get_attend_fn(causal = True, seq_len = time, k_seq_len = time, **attend_kwargs)
 
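space_attend and time_attend are built with different sequence lengths, which matches the usual factorized space/time attention where the other axis is folded into the batch dimension; the actual wiring inside DynamicsModel is not shown in this hunk, so the rearranges below are only an assumption for illustration:

import torch
from einops import rearrange

b, t, s, d = 2, 5, 8, 16
tokens = torch.randn(b, t, s, d)

space_tokens = rearrange(tokens, 'b t s d -> (b t) s d')   # what space_attend would see (s positions, non-causal, agent token on the right)
time_tokens  = rearrange(tokens, 'b t s d -> (b s) t d')   # what time_attend would see (t positions, causal)
print(space_tokens.shape, time_tokens.shape)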
@@ -64,11 +64,15 @@ def test_symexp_two_hot():
 @param('softclamp_value', (50., None))
 @param('num_agent_tokens', (0, 1))
 @param('causal_block_size', (1, 8))
+@param('block_size_per_special', (1, 8))
+@param('special_attend_only_itself', (False, True))
 def test_attend_factory(
     causal,
     softclamp_value,
     num_agent_tokens,
-    causal_block_size
+    causal_block_size,
+    block_size_per_special,
+    special_attend_only_itself
 ):
 
     from dreamer4.dreamer4 import get_attend_fn
@@ -84,7 +88,9 @@ def test_attend_factory(
         causal_block_size = causal_block_size,
         softclamp_value = softclamp_value,
         device = q.device,
-        num_agent_tokens = num_agent_tokens
+        num_agent_tokens = num_agent_tokens,
+        block_size_per_special = block_size_per_special,
+        special_attend_only_itself = special_attend_only_itself
     )
 
     attend = get_attend_fn(True, **attend_kwargs)