grouped query attention is necessary

This commit is contained in:
lucidrains 2025-10-04 06:31:32 -07:00
parent 58a6964dd9
commit 8373cb13ec
2 changed files with 97 additions and 42 deletions

View File

@ -265,7 +265,7 @@ def flex_block_mask(
# assume special tokens (either latent or agent tokens) are placed at the right hand side
# so [modality] [latents | agent]
def create_mask(_, __, qi, ki):
def create_mask(b, __, qi, ki):
q_block_index = qi // block_size
k_block_index = ki // block_size
@ -274,7 +274,7 @@ def flex_block_mask(
q_is_special = (qi % block_size) >= special_token_index_start
k_is_special = (ki % block_size) >= special_token_index_start
mask = True
mask = b >= -1 # make shift True tensor
if is_causal:
mask &= q_block_index >= k_block_index
@ -285,7 +285,7 @@ def flex_block_mask(
if prevent_special_to_modality:
mask &= ~(~q_is_special & k_is_special)
return causal_mask
return mask
block_mask = create_block_mask(create_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)
return block_mask
@ -312,6 +312,60 @@ def nonflex_block_causal_mask(seq_len, block_size, device = None):
return block_causal_mask[:seq_len, :seq_len]
# attend functions
def naive_attend(
q, k, v,
softclamp_value = 50.,
scale = None,
causal = False,
mask = None
):
if not exists(scale):
scale = q.shape[-1] ** -0.5
# grouped query attention
groups = q.shape[1] // k.shape[1]
q = rearrange(q, 'b (h g) ... -> b h g ...', g = groups)
# similarity
sim = einsum(q, k, 'b h g i d, b h j d -> b h g i j')
# softclamping a la gemma 3
if exists(softclamp_value):
sim = softclamp(sim, softclamp_value)
# scale and attention
sim = sim * scale
# masking
mask_value = -torch.finfo(sim.dtype).max
if exists(mask):
sim = sim.masked_fill(~mask, mask_value)
if causal:
i, j = sim.shape[-2:]
causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)
sim = sim.masked_fill(causal_mask, mask_value)
# attend
attn = sim.softmax(dim = -1)
# aggregate
out = einsum(attn, v, 'b h g i j, b h j d -> b h i d')
return out
# attention
class Attention(Module):
@ -319,6 +373,7 @@ class Attention(Module):
self,
dim,
dim_head = 64,
query_heads = None,
heads = 8,
softclamp_value = 50.,
pre_rmsnorm = True,
@ -327,14 +382,23 @@ class Attention(Module):
super().__init__()
self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
# setup grouped query attention
query_heads = default(query_heads, heads)
assert query_heads >= heads and divisible_by(query_heads, heads)
# scaling, splitting and merging of heads
self.scale = dim_head ** -0.5
self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
self.split_heads = Rearrange('b n (h d) -> b h n d', d = dim_head)
self.merge_heads = Rearrange('b h n d -> b n (h d)')
dim_inner = dim_head * heads
self.to_q = LinearNoBias(dim, dim_inner)
self.to_kv = LinearNoBias(dim, dim_inner * 2)
self.to_out = LinearNoBias(dim_inner, dim)
dim_q_inner = dim_head * query_heads
dim_kv_inner = dim_head * heads
self.to_q = LinearNoBias(dim, dim_q_inner)
self.to_kv = LinearNoBias(dim, dim_kv_inner * 2)
self.to_out = LinearNoBias(dim_kv_inner, dim)
# masking related
@ -342,7 +406,7 @@ class Attention(Module):
# stability related
self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads)
self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads)
self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads)
self.softclamp_value = softclamp_value
@ -376,38 +440,15 @@ class Attention(Module):
k = cat((ck, k), dim = -2)
v = cat((cv, v), dim = -2)
# similarity
# attention
sim = einsum(q, k, 'b h i d, b h j d -> b h i j')
# softclamping a la gemma 3
if exists(self.softclamp_value):
sim = softclamp(sim, self.softclamp_value)
# scale and attention
sim = sim * self.scale
# masking
mask_value = -torch.finfo(sim.dtype).max
if exists(mask):
sim = sim.masked_fill(~mask, mask_value)
if self.causal:
i, j = sim.shape[-2:]
causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)
sim = sim.masked_fill(causal_mask, mask_value)
# attend
attn = sim.softmax(dim = -1)
# aggregate
out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
out = naive_attend(
q, k, v,
softclamp_value = self.softclamp_value,
scale = self.scale,
causal = self.causal,
mask = mask
)
# merge heads

View File

@ -3,8 +3,10 @@ param = pytest.mark.parametrize
import torch
@param('pred_orig_latent', (False, True))
@param('gqa', (False, True))
def test_e2e(
pred_orig_latent
pred_orig_latent,
gqa
):
from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel
@ -17,7 +19,19 @@ def test_e2e(
latents = tokenizer(x, return_latents = True)
assert latents.shape[-1] == 32
dynamics = DynamicsModel(512, dim_latent = 32, num_signal_levels = 500, num_step_sizes = 32, pred_orig_latent = pred_orig_latent)
query_heads, heads = (16, 4) if gqa else (8, 8)
dynamics = DynamicsModel(
512,
dim_latent = 32,
num_signal_levels = 500,
num_step_sizes = 32,
pred_orig_latent = pred_orig_latent,
attn_kwargs = dict(
heads = heads,
query_heads = query_heads
)
)
signal_levels = torch.randint(0, 500, (2, 4))
step_sizes = torch.randint(0, 32, (2, 4))