diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index f545b2a..a337820 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -265,7 +265,7 @@ def flex_block_mask(
     # assume special tokens (either latent or agent tokens) are placed at the right hand side
     # so [modality] [latents | agent]
 
-    def create_mask(_, __, qi, ki):
+    def create_mask(b, __, qi, ki):
         q_block_index = qi // block_size
         k_block_index = ki // block_size
 
@@ -274,7 +274,7 @@ def flex_block_mask(
         q_is_special = (qi % block_size) >= special_token_index_start
         k_is_special = (ki % block_size) >= special_token_index_start
 
-        mask = True
+        mask = b >= -1 # makeshift always-True tensor
 
         if is_causal:
            mask &= q_block_index >= k_block_index
@@ -285,7 +285,7 @@ def flex_block_mask(
         if prevent_special_to_modality:
             mask &= ~(~q_is_special & k_is_special)
 
-        return causal_mask
+        return mask
 
     block_mask = create_block_mask(create_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)
     return block_mask
@@ -312,6 +312,60 @@ def nonflex_block_causal_mask(seq_len, block_size, device = None):
 
     return block_causal_mask[:seq_len, :seq_len]
 
+# attend functions
+
+def naive_attend(
+    q, k, v,
+    softclamp_value = 50.,
+    scale = None,
+    causal = False,
+    mask = None
+):
+
+    if not exists(scale):
+        scale = q.shape[-1] ** -0.5
+
+    # grouped query attention
+
+    groups = q.shape[1] // k.shape[1]
+
+    q = rearrange(q, 'b (h g) ... -> b h g ...', g = groups)
+
+    # similarity
+
+    sim = einsum(q, k, 'b h g i d, b h j d -> b h g i j')
+
+    # softclamping a la gemma 3
+
+    if exists(softclamp_value):
+        sim = softclamp(sim, softclamp_value)
+
+    # scale and attention
+
+    sim = sim * scale
+
+    # masking
+
+    mask_value = -torch.finfo(sim.dtype).max
+
+    if exists(mask):
+        sim = sim.masked_fill(~mask, mask_value)
+
+    if causal:
+        i, j = sim.shape[-2:]
+        causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)
+        sim = sim.masked_fill(causal_mask, mask_value)
+
+    # attend
+
+    attn = sim.softmax(dim = -1)
+
+    # aggregate, then merge query head groups back with their kv heads
+
+    out = einsum(attn, v, 'b h g i j, b h j d -> b h g i d')
+
+    return rearrange(out, 'b h g i d -> b (h g) i d')
+
 # attention
 
 class Attention(Module):
@@ -319,6 +373,7 @@ class Attention(Module):
         self,
         dim,
         dim_head = 64,
+        query_heads = None,
         heads = 8,
         softclamp_value = 50.,
         pre_rmsnorm = True,
@@ -327,14 +382,23 @@ class Attention(Module):
         super().__init__()
         self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
 
+        # setup grouped query attention
+
+        query_heads = default(query_heads, heads)
+        assert query_heads >= heads and divisible_by(query_heads, heads)
+
+        # scaling, splitting and merging of heads
+
         self.scale = dim_head ** -0.5
-        self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
+        self.split_heads = Rearrange('b n (h d) -> b h n d', d = dim_head)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
 
-        dim_inner = dim_head * heads
-        self.to_q = LinearNoBias(dim, dim_inner)
-        self.to_kv = LinearNoBias(dim, dim_inner * 2)
-        self.to_out = LinearNoBias(dim_inner, dim)
+        dim_q_inner = dim_head * query_heads
+        dim_kv_inner = dim_head * heads
+
+        self.to_q = LinearNoBias(dim, dim_q_inner)
+        self.to_kv = LinearNoBias(dim, dim_kv_inner * 2)
+        self.to_out = LinearNoBias(dim_q_inner, dim)
 
         # masking related
 
@@ -342,7 +406,7 @@ class Attention(Module):
 
         # stability related
 
-        self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads)
+        self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads)
         self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads)
 
         self.softclamp_value = softclamp_value
@@ -376,38 +440,15 @@ class Attention(Module):
             k = cat((ck, k), dim = -2)
             v = cat((cv, v), dim = -2)
 
-        # similarity
+        # attention
 
-        sim = einsum(q, k, 'b h i d, b h j d -> b h i j')
-
-        # softclamping a la gemma 3
-
-        if exists(self.softclamp_value):
-            sim = softclamp(sim, self.softclamp_value)
-
-        # scale and attention
-
-        sim = sim * self.scale
-
-        # masking
-
-        mask_value = -torch.finfo(sim.dtype).max
-
-        if exists(mask):
-            sim = sim.masked_fill(~mask, mask_value)
-
-        if self.causal:
-            i, j = sim.shape[-2:]
-            causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)
-            sim = sim.masked_fill(causal_mask, mask_value)
-
-        # attend
-
-        attn = sim.softmax(dim = -1)
-
-        # aggregate
-
-        out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
+        out = naive_attend(
+            q, k, v,
+            softclamp_value = self.softclamp_value,
+            scale = self.scale,
+            causal = self.causal,
+            mask = mask
+        )
 
         # merge heads
 
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index 3a7c2a8..8fc606a 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -3,8 +3,10 @@ param = pytest.mark.parametrize
 
 import torch
 
 @param('pred_orig_latent', (False, True))
+@param('gqa', (False, True))
 def test_e2e(
-    pred_orig_latent
+    pred_orig_latent,
+    gqa
 ):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel
 
@@ -17,7 +19,19 @@ def test_e2e(
     latents = tokenizer(x, return_latents = True)
     assert latents.shape[-1] == 32
 
-    dynamics = DynamicsModel(512, dim_latent = 32, num_signal_levels = 500, num_step_sizes = 32, pred_orig_latent = pred_orig_latent)
+    query_heads, heads = (16, 4) if gqa else (8, 8)
+
+    dynamics = DynamicsModel(
+        512,
+        dim_latent = 32,
+        num_signal_levels = 500,
+        num_step_sizes = 32,
+        pred_orig_latent = pred_orig_latent,
+        attn_kwargs = dict(
+            heads = heads,
+            query_heads = query_heads
+        )
+    )
 
     signal_levels = torch.randint(0, 500, (2, 4))
     step_sizes = torch.randint(0, 32, (2, 4))
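
For reference, a minimal shape-check sketch of the grouped query attention path introduced above. It assumes `naive_attend` is importable from `dreamer4.dreamer4` as added in this diff, and that its output keeps one vector per query head (ready for `merge_heads` and `to_out`); the tensor sizes are arbitrary.

```python
import torch
from dreamer4.dreamer4 import naive_attend  # function added in this diff

batch, seq, dim_head = 2, 16, 64
query_heads, kv_heads = 16, 4   # every 4 query heads share one kv head

q = torch.randn(batch, query_heads, seq, dim_head)
k = torch.randn(batch, kv_heads, seq, dim_head)
v = torch.randn(batch, kv_heads, seq, dim_head)

out = naive_attend(q, k, v, causal = True)

# one output vector per query head
assert out.shape == (batch, query_heads, seq, dim_head)
```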