the dynamics model's spatial attention uses a non-causal attention pattern, but with nothing else attending to the agent tokens
This commit is contained in:
parent
77ad96ded2
commit
58a6964dd9
@ -252,10 +252,11 @@ class MultiHeadRMSNorm(Module):
|
|||||||
# masking related
|
# masking related
|
||||||
# block causal mask (space fully attends within each block, while time is causal)
|
# block mask (space fully attends within each block, while time is causal only when is_causal is True)
|
||||||
|
|
||||||
def flex_block_causal_mask(
|
def flex_block_mask(
|
||||||
seq_len,
|
seq_len,
|
||||||
block_size,
|
block_size,
|
||||||
num_special_tokens = 0,
|
num_special_tokens = 0,
|
||||||
|
is_causal = True,
|
||||||
prevent_modality_to_special = False, # encoder of tokenizer as well as (perhaps crucially) the dynamics model
|
prevent_modality_to_special = False, # encoder of tokenizer as well as (perhaps crucially) the dynamics model
|
||||||
prevent_special_to_modality = False # decoder of tokenizer
|
prevent_special_to_modality = False # decoder of tokenizer
|
||||||
):
|
):
|
||||||
@ -273,13 +274,16 @@ def flex_block_causal_mask(
|
|||||||
q_is_special = (qi % block_size) >= special_token_index_start
|
q_is_special = (qi % block_size) >= special_token_index_start
|
||||||
k_is_special = (ki % block_size) >= special_token_index_start
|
k_is_special = (ki % block_size) >= special_token_index_start
|
||||||
|
|
||||||
causal_mask = q_block_index >= k_block_index
|
mask = True
|
||||||
|
|
||||||
|
if is_causal:
|
||||||
|
mask &= q_block_index >= k_block_index
|
||||||
|
|
||||||
if prevent_modality_to_special:
|
if prevent_modality_to_special:
|
||||||
causal_mask = causal_mask & ~(q_is_special & ~k_is_special)
|
mask &= ~(q_is_special & ~k_is_special)
|
||||||
|
|
||||||
if prevent_special_to_modality:
|
if prevent_special_to_modality:
|
||||||
causal_mask = causal_mask & ~(~q_is_special & k_is_special)
|
mask &= ~(~q_is_special & k_is_special)
|
||||||
|
|
||||||
return causal_mask
|
return mask
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user