diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index ef4b058..0ab6a76 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -416,29 +416,7 @@ def flex_block_mask(
     block_mask = create_block_mask(create_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)
     return block_mask
 
-# for softclamping with flex attention
-
-def softclamp_score_mod(value):
-
-    def inner(attn_logits, b, h, qi, ki):
-        attn_logits = attn_logits / value
-        attn_logits = torch.tanh(attn_logits)
-        attn_logits = attn_logits * value
-        return attn_logits
-
-    return inner
-
-# todo - reuse the inner function from flex attn above with broadcasting
-
-def block_causal_mask(seq_len, block_size, device = None):
-    blocks = ceil(seq_len / block_size)
-
-    causal_mask = torch.ones((blocks, blocks), device = device, dtype = torch.bool).tril()
-    block_causal_mask = repeat(causal_mask, 'i j -> (i block_size1) (j block_size2)', block_size1 = block_size, block_size2 = block_size)
-
-    return block_causal_mask[:seq_len, :seq_len]
-
-# attend functions
+# naive attend
 
 def naive_attend(
     q, k, v,
@@ -461,15 +439,15 @@ def naive_attend(
 
     sim = einsum(q, k, 'b h g i d, b h j d -> b h g i j')
 
+    # scale and attention
+
+    sim = sim * scale
+
     # softclamping a la gemma 3
 
     if exists(softclamp_value):
         sim = softclamp(sim, softclamp_value)
 
-    # scale and attention
-
-    sim = sim * scale
-
     # masking
 
     mask_value = -torch.finfo(sim.dtype).max
@@ -488,9 +466,89 @@ def naive_attend(
 
     # aggregate
 
-    out = einsum(attn, v, 'b h g i j, b h j d -> b h i d')
+    out = einsum(attn, v, 'b h g i j, b h j d -> b h g i d')
 
-    return out
+    # merge the groups
+
+    return rearrange(out, 'b h g i d -> b (h g) i d')
+
+# flex attention related and factory function for attend depending on whether on cuda + flex attention available
+
+def block_mask_causal(b, h, q, k):
+    return q >= k
+
+def agent_token_mask(q, k, seq_len, num_tokens):
+    is_special_start_index = seq_len - num_tokens
+    q_is_special = q >= is_special_start_index
+    k_is_special = k >= is_special_start_index
+    return ~(~q_is_special & k_is_special) # modality cannot attend to agent tokens
+
+def block_mask_special_tokens_right(
+    seq_len,
+    num_tokens
+):
+    def inner(b, h, q, k):
+        return agent_token_mask(q, k, seq_len, num_tokens)
+    return inner
+
+def compose_mask(mask1, mask2):
+    def inner(b, h, q, k):
+        return mask1(b, h, q, k) & mask2(b, h, q, k)
+
+    return inner
+
+def block_mask_noop(b, h, q, k):
+    return b >= 0
+
+def score_mod_softclamp(value):
+    def inner(sim, b, h, q, k):
+        if not exists(value):
+            return sim
+
+        sim = sim / value
+        sim = torch.tanh(sim)
+        sim = sim * value
+        return sim
+
+    return inner
+
+# factory for attend function
+
+def get_attend_fn(
+    use_flex,
+    seq_len,
+    k_seq_len,
+    causal = False,
+    softclamp_value = 50.,
+    num_agent_tokens = 0,
+    device = None
+):
+    if use_flex:
+        # flex pathway
+
+        block_mask_fn = block_mask_causal if causal else block_mask_noop
+
+        if num_agent_tokens > 0:
+            agent_block_mask = block_mask_special_tokens_right(k_seq_len, num_agent_tokens)
+            block_mask_fn = compose_mask(block_mask_fn, agent_block_mask)
+
+        block_mask = create_block_mask(block_mask_fn, B = None, H = None, Q_LEN = seq_len, KV_LEN = k_seq_len)
+
+        score_mod = score_mod_softclamp(softclamp_value)
+        attend_fn = partial(flex_attention, block_mask = block_mask, score_mod = score_mod, enable_gqa = True)
+    else:
+        # naive pathway
+
+        mask = None
+        if num_agent_tokens > 0:
+            q_seq = torch.arange(seq_len, device = device)[:, None]
+            k_seq = torch.arange(k_seq_len, device = device)[None, :]
+
+            mask = agent_token_mask(q_seq, k_seq, k_seq_len, num_agent_tokens)
+
+        attend_fn = partial(naive_attend, causal = causal, mask = mask, softclamp_value = softclamp_value)
+
+    return attend_fn
 
 # attention
 
@@ -521,7 +579,7 @@ class Attention(Module):
 
         self.to_q = LinearNoBias(dim, dim_q_inner)
         self.to_kv = LinearNoBias(dim, dim_kv_inner * 2)
-        self.to_out = LinearNoBias(dim_kv_inner, dim)
+        self.to_out = LinearNoBias(dim_q_inner, dim)
 
         # stability related
 
@@ -949,6 +1007,15 @@ class DynamicsModel(Module):
 
         self.action_learned_embed = Parameter(torch.randn(dim) * 1e-2)
 
+        # calculate "space" seq len
+
+        self.space_seq_len = (
+            1 # action / agent token
+            + 1 # signal + step
+            + num_register_tokens
+            + num_spatial_tokens
+        )
+
         # attention
 
         self.attn_softclamp_value = attn_softclamp_value
@@ -1013,7 +1080,7 @@ class DynamicsModel(Module):
 
         latents = self.video_tokenizer.tokenize(video)
 
-        time = latents.shape[1]
+        time, device = latents.shape[1], latents.device
 
         # flow related
 
@@ -1070,11 +1137,15 @@ class DynamicsModel(Module):
 
         # attend functions for space and time
 
-        attend_kwargs = dict(softclamp_value = self.attn_softclamp_value)
+        seq_len = tokens.shape[1]
 
-        space_attend = partial(naive_attend, causal = False, **attend_kwargs)
+        use_flex = exists(flex_attention) and tokens.is_cuda
 
-        time_attend = partial(naive_attend, causal = True, **attend_kwargs)
+        attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, device = device)
+
+        space_attend = get_attend_fn(causal = False, seq_len = self.space_seq_len, k_seq_len = self.space_seq_len, num_agent_tokens = 1, **attend_kwargs) # space has an agent token on the right-hand side for reinforcement learning - cannot be attended to by modality
+
+        time_attend = get_attend_fn(causal = True, seq_len = time, k_seq_len = time, **attend_kwargs)
 
         # rotary
 
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index c14a6d5..0a7a2e3 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -58,3 +58,29 @@ def test_symexp_two_hot():
     recon_values = two_hot_encoder.logits_to_scalar_value(encoded)
 
     assert torch.allclose(recon_values, values, atol = 1e-6)
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason = 'no cuda')
+@param('causal', (False, True))
+@param('softclamp_value', (50., None))
+@param('num_agent_tokens', (0, 1))
+def test_attend_factory(
+    causal,
+    softclamp_value,
+    num_agent_tokens
+):
+
+    from dreamer4.dreamer4 import get_attend_fn
+
+    q = torch.randn(1, 8, 1024, 512).cuda()
+    k = torch.randn(1, 4, 1024, 512).cuda()
+    v = torch.randn(1, 4, 1024, 512).cuda()
+
+    attend_kwargs = dict(seq_len = 1024, k_seq_len = 1024, causal = causal, softclamp_value = softclamp_value, device = q.device, num_agent_tokens = num_agent_tokens)
+
+    attend = get_attend_fn(True, **attend_kwargs)
+    flex_out = attend(q, k, v)
+
+    attend = get_attend_fn(False, **attend_kwargs)
+    out = attend(q, k, v)
+
+    assert torch.allclose(flex_out, out, atol = 1e-6)
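
A minimal sketch of how the new get_attend_fn factory is intended to be called, mirroring the added test but with smaller, arbitrary shapes; use_flex = False selects the naive pathway, so it runs on cpu without flex attention:

import torch
from dreamer4.dreamer4 import get_attend_fn

# grouped query attention - 8 query heads share 4 key / value heads
q = torch.randn(1, 8, 256, 64)
k = torch.randn(1, 4, 256, 64)
v = torch.randn(1, 4, 256, 64)

# naive pathway, causal over the sequence, one agent token on the right that modality tokens cannot attend to
attend = get_attend_fn(
    False,
    seq_len = 256,
    k_seq_len = 256,
    causal = True,
    softclamp_value = 50.,
    num_agent_tokens = 1,
    device = q.device
)

out = attend(q, k, v)
assert out.shape == (1, 8, 256, 64)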