diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 725f167..68a5a08 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -376,47 +376,6 @@ class MultiHeadRMSNorm(Module):
         scale = (self.gamma + 1.) * self.scale
         return einx.multiply('... h n d, h d', normed, scale)
 
-# masking related
-# block causal mask (space fully attends within each block, while time is causal)
-
-def flex_block_mask(
-    seq_len,
-    block_size,
-    num_special_tokens = 0,
-    is_causal = True,
-    prevent_modality_to_special = False, # encoder of tokenizer as well as (perhaps crucially) the dynamics model
-    prevent_special_to_modality = False # decoder of tokenizer
-):
-    assert num_special_tokens <= block_size
-
-    # assume special tokens (either latent or agent tokens) are placed at the right hand side
-    # so [modality] [latents | agent]
-
-    def create_mask(b, __, qi, ki):
-        q_block_index = qi // block_size
-        k_block_index = ki // block_size
-
-        special_token_index_start = block_size - num_special_tokens
-
-        q_is_special = (qi % block_size) >= special_token_index_start
-        k_is_special = (ki % block_size) >= special_token_index_start
-
-        mask = b >= -1 # make shift True tensor
-
-        if is_causal:
-            mask &= q_block_index >= k_block_index
-
-        if prevent_modality_to_special:
-            mask &= ~(q_is_special & ~k_is_special)
-
-        if prevent_special_to_modality:
-            mask &= ~(~q_is_special & k_is_special)
-
-        return mask
-
-    block_mask = create_block_mask(create_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)
-    return block_mask
-
 # naive attend
 
 def naive_attend(
@@ -493,20 +452,31 @@ def block_mask_causal(block_size):
         bq = q // block_size
         bk = k // block_size
         return bq >= bk
+
     return inner
 
-def agent_token_mask(q, k, seq_len, num_tokens):
+def special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself = False):
+    bq = q % seq_len
+    bk = k % seq_len
+
     is_special_start_index = seq_len - num_tokens
-    q_is_special = q >= is_special_start_index
-    k_is_special = k >= is_special_start_index
-    return ~(~q_is_special & k_is_special) # modality cannot attend to agent tokens
+
+    q_is_special = bq >= is_special_start_index
+    k_is_special = bk >= is_special_start_index
+
+    if special_attend_only_itself:
+        out = ~(q_is_special & ~k_is_special) # modality attends to everything, but latent can only attend to itself (proposed attention pattern for encoder of video tokenizer)
+    else:
+        out = ~(~q_is_special & k_is_special) # modality cannot attend to agent tokens
+
+    return out
 
 def block_mask_special_tokens_right(
     seq_len,
-    num_tokens
+    num_tokens,
+    special_attend_only_itself = False
 ):
     def inner(b, h, q, k):
-        return agent_token_mask(q, k, seq_len, num_tokens)
+        return special_token_mask(q, k, seq_len, num_tokens, special_attend_only_itself)
     return inner
 
 def compose_mask(mask1, mask2):
@@ -539,17 +509,21 @@ def get_attend_fn(
     causal = False,
     causal_block_size = 1,
     softclamp_value = 50.,
-    num_agent_tokens = 0,
+    num_special_tokens = 0, # special tokens are latents / agents
+    block_size_per_special = None, # defaults to k_seq_len
+    special_attend_only_itself = False, # by default, modality only attends to itself while special sees everything, but if turned True, will be the inverse - special can only attend to itself but modality can attend everything
     device = None
 ):
+    block_size_per_special = default(block_size_per_special, k_seq_len)
+
     if use_flex:
         # flex pathway
 
         block_mask_fn = block_mask_causal(causal_block_size) if causal else block_mask_noop
 
-        if num_agent_tokens > 0:
-            agent_block_mask = block_mask_special_tokens_right(k_seq_len, num_agent_tokens)
-            block_mask_fn = compose_mask(block_mask_fn, agent_block_mask)
+        if num_special_tokens > 0:
+            special_block_mask = block_mask_special_tokens_right(block_size_per_special, num_special_tokens, special_attend_only_itself)
+            block_mask_fn = compose_mask(block_mask_fn, special_block_mask)
 
         block_mask = create_block_mask(block_mask_fn, B = None, H = None, Q_LEN = seq_len, KV_LEN = k_seq_len)
@@ -559,11 +533,11 @@ def get_attend_fn(
         # naive pathway
 
         mask = None
 
-        if num_agent_tokens > 0:
+        if num_special_tokens > 0:
            q_seq = torch.arange(seq_len, device = device)[:, None]
            k_seq = torch.arange(k_seq_len, device = device)[None, :]
 
-            mask = agent_token_mask(q_seq, k_seq, k_seq_len, num_agent_tokens)
+            mask = special_token_mask(q_seq, k_seq, block_size_per_special, num_special_tokens, special_attend_only_itself)
 
         attend_fn = partial(naive_attend, causal = causal, causal_block_size = causal_block_size, mask = mask, softclamp_value = softclamp_value)
@@ -890,17 +864,24 @@ class VideoTokenizer(Module):
         attend_kwargs = dict(
             causal = True,
             causal_block_size = space_seq_len,
-            softclamp_value = self.attn_softclamp_value
+            softclamp_value = self.attn_softclamp_value,
+            block_size_per_special = space_seq_len,
+            num_special_tokens = 1
         )
 
         use_flex = tokens.is_cuda and exists(flex_attention)
 
-        attend_fn = get_attend_fn(use_flex, seq_len, seq_len)
+        # encoder attend
+
+        # latents can only attend to themselves, while modality can attend to everything
+        # the inverse of the agent token pattern in the dynamics model
+
+        encoder_attend_fn = get_attend_fn(use_flex, seq_len, seq_len, special_attend_only_itself = True)
 
         # encoder
 
         for attn, ff in self.encoder_layers:
-            tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = attend_fn) + tokens
+            tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = encoder_attend_fn) + tokens
             tokens = ff(tokens) + tokens
 
         tokens = self.encoder_norm(tokens)
@@ -930,10 +911,14 @@ class VideoTokenizer(Module):
         tokens, _ = pack((decoder_pos_emb, latent_tokens), 'b * d')
 
+        # decoder attend
+
+        decoder_attend_fn = get_attend_fn(use_flex, seq_len, seq_len)
+
         # decoder attention
 
         for attn, ff in self.decoder_layers:
-            tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = attend_fn) + tokens
+            tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = decoder_attend_fn) + tokens
             tokens = ff(tokens) + tokens
@@ -1172,7 +1157,7 @@ class DynamicsModel(Module):
 
         attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, device = device)
 
-        space_attend = get_attend_fn(causal = False, seq_len = self.space_seq_len, k_seq_len = self.space_seq_len, num_agent_tokens = 1, **attend_kwargs) # space has an agent token on the right-hand side for reinforcement learning - cannot be attended to by modality
+        space_attend = get_attend_fn(causal = False, seq_len = self.space_seq_len, k_seq_len = self.space_seq_len, num_special_tokens = 1, **attend_kwargs) # space has an agent token on the right-hand side for reinforcement learning - cannot be attended to by modality
 
         time_attend = get_attend_fn(causal = True, seq_len = time, k_seq_len = time, **attend_kwargs)
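Note on the masking rule above: `special_token_mask` treats the last `num_tokens` positions of every block of size `seq_len` (passed as `block_size_per_special` from `get_attend_fn`) as the special latent / agent tokens. As a reading aid, here is a minimal standalone sketch of the same rule in plain `torch`; the helper name `toy_special_mask` and the toy sizes are illustrative only and are not part of the library:

```python
import torch

def toy_special_mask(seq_len, block_size, num_special, special_attend_only_itself = False):
    # absolute query / key positions on a (seq_len, seq_len) grid
    q = torch.arange(seq_len)[:, None]
    k = torch.arange(seq_len)[None, :]

    # special (latent / agent) tokens sit at the right hand side of every block
    special_start = block_size - num_special
    q_is_special = (q % block_size) >= special_start
    k_is_special = (k % block_size) >= special_start

    if special_attend_only_itself:
        # special tokens may only look at other special tokens, modality is unrestricted
        return ~(q_is_special & ~k_is_special)

    # default: modality tokens may not look at special tokens, special is unrestricted
    return ~(~q_is_special & k_is_special)

# two blocks of 4 tokens, one special token at the end of each block
# True means attention is allowed
print(toy_special_mask(seq_len = 8, block_size = 4, num_special = 1).int())
```

The default branch matches the agent token in the dynamics model (the agent reads everything, modality never reads the agent), while the `special_attend_only_itself = True` branch is the inverse of that pattern.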
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index be28f06..6b22786 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -64,11 +64,15 @@ def test_symexp_two_hot():
 @param('softclamp_value', (50., None))
-@param('num_agent_tokens', (0, 1))
+@param('num_special_tokens', (0, 1))
 @param('causal_block_size', (1, 8))
+@param('block_size_per_special', (1, 8))
+@param('special_attend_only_itself', (False, True))
 def test_attend_factory(
     causal,
     softclamp_value,
-    num_agent_tokens,
-    causal_block_size
+    num_special_tokens,
+    causal_block_size,
+    block_size_per_special,
+    special_attend_only_itself
 ):
@@ -84,7 +88,9 @@ def test_attend_factory(
         causal_block_size = causal_block_size,
         softclamp_value = softclamp_value,
         device = q.device,
-        num_agent_tokens = num_agent_tokens
+        num_special_tokens = num_special_tokens,
+        block_size_per_special = block_size_per_special,
+        special_attend_only_itself = special_attend_only_itself
     )
 
     attend = get_attend_fn(True, **attend_kwargs)
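The two new test parameters sweep exactly the combination the flex pathway builds with `compose_mask`: a block-causal rule ANDed with the special-token rule. Below is a standalone sketch of that composition using plain boolean masks only; the helper names are hypothetical and no attention kernel from the repository is invoked:

```python
import torch

def toy_block_causal(seq_len, block_size):
    # a block may attend to itself and to earlier blocks (time is causal, space is full within a block)
    q = torch.arange(seq_len)[:, None]
    k = torch.arange(seq_len)[None, :]
    return (q // block_size) >= (k // block_size)

def toy_special_mask(seq_len, block_size, num_special, special_attend_only_itself = False):
    # same rule as in the sketch above, repeated so this snippet runs on its own
    q = torch.arange(seq_len)[:, None]
    k = torch.arange(seq_len)[None, :]
    special_start = block_size - num_special
    q_is_special = (q % block_size) >= special_start
    k_is_special = (k % block_size) >= special_start
    if special_attend_only_itself:
        return ~(q_is_special & ~k_is_special)
    return ~(~q_is_special & k_is_special)

# 4 blocks of 8 tokens, causal over blocks, one agent-style token per block that modality cannot see
seq_len, block_size = 32, 8
mask = toy_block_causal(seq_len, block_size) & toy_special_mask(seq_len, block_size, num_special = 1)

# the agent token of block 2 (position 23) may read every position up to and including its own block ...
assert mask[23, :24].all()
# ... but a modality token of block 2 may not read any agent token, past or present
assert not mask[16, 7] and not mask[16, 23]
```

This mirrors what the parametrized test exercises: the flex and naive pathways of `get_attend_fn` should agree for every combination of causality, block size, special-token count, and `special_attend_only_itself`.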