grouped query attention is necessary
This commit is contained in:
parent
58a6964dd9
commit
8373cb13ec
@ -265,7 +265,7 @@ def flex_block_mask(
|
||||
# assume special tokens (either latent or agent tokens) are placed at the right hand side
|
||||
# so [modality] [latents | agent]
|
||||
|
||||
def create_mask(_, __, qi, ki):
|
||||
def create_mask(b, __, qi, ki):
|
||||
q_block_index = qi // block_size
|
||||
k_block_index = ki // block_size
|
||||
|
||||
@ -274,7 +274,7 @@ def flex_block_mask(
|
||||
q_is_special = (qi % block_size) >= special_token_index_start
|
||||
k_is_special = (ki % block_size) >= special_token_index_start
|
||||
|
||||
mask = True
|
||||
mask = b >= -1 # make shift True tensor
|
||||
|
||||
if is_causal:
|
||||
mask &= q_block_index >= k_block_index
|
||||
@ -285,7 +285,7 @@ def flex_block_mask(
|
||||
if prevent_special_to_modality:
|
||||
mask &= ~(~q_is_special & k_is_special)
|
||||
|
||||
return causal_mask
|
||||
return mask
|
||||
|
||||
block_mask = create_block_mask(create_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)
|
||||
return block_mask
|
||||
@ -312,6 +312,60 @@ def nonflex_block_causal_mask(seq_len, block_size, device = None):
|
||||
|
||||
return block_causal_mask[:seq_len, :seq_len]
|
||||
|
||||
# attend functions
|
||||
|
||||
def naive_attend(
|
||||
q, k, v,
|
||||
softclamp_value = 50.,
|
||||
scale = None,
|
||||
causal = False,
|
||||
mask = None
|
||||
):
|
||||
|
||||
if not exists(scale):
|
||||
scale = q.shape[-1] ** -0.5
|
||||
|
||||
# grouped query attention
|
||||
|
||||
groups = q.shape[1] // k.shape[1]
|
||||
|
||||
q = rearrange(q, 'b (h g) ... -> b h g ...', g = groups)
|
||||
|
||||
# similarity
|
||||
|
||||
sim = einsum(q, k, 'b h g i d, b h j d -> b h g i j')
|
||||
|
||||
# softclamping a la gemma 3
|
||||
|
||||
if exists(softclamp_value):
|
||||
sim = softclamp(sim, softclamp_value)
|
||||
|
||||
# scale and attention
|
||||
|
||||
sim = sim * scale
|
||||
|
||||
# masking
|
||||
|
||||
mask_value = -torch.finfo(sim.dtype).max
|
||||
|
||||
if exists(mask):
|
||||
sim = sim.masked_fill(~mask, mask_value)
|
||||
|
||||
if causal:
|
||||
i, j = sim.shape[-2:]
|
||||
causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)
|
||||
sim = sim.masked_fill(causal_mask, mask_value)
|
||||
|
||||
# attend
|
||||
|
||||
attn = sim.softmax(dim = -1)
|
||||
|
||||
# aggregate
|
||||
|
||||
out = einsum(attn, v, 'b h g i j, b h j d -> b h i d')
|
||||
|
||||
return out
|
||||
|
||||
# attention
|
||||
|
||||
class Attention(Module):
|
||||
@ -319,6 +373,7 @@ class Attention(Module):
|
||||
self,
|
||||
dim,
|
||||
dim_head = 64,
|
||||
query_heads = None,
|
||||
heads = 8,
|
||||
softclamp_value = 50.,
|
||||
pre_rmsnorm = True,
|
||||
@ -327,14 +382,23 @@ class Attention(Module):
|
||||
super().__init__()
|
||||
self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()
|
||||
|
||||
# setup grouped query attention
|
||||
|
||||
query_heads = default(query_heads, heads)
|
||||
assert query_heads >= heads and divisible_by(query_heads, heads)
|
||||
|
||||
# scaling, splitting and merging of heads
|
||||
|
||||
self.scale = dim_head ** -0.5
|
||||
self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
|
||||
self.split_heads = Rearrange('b n (h d) -> b h n d', d = dim_head)
|
||||
self.merge_heads = Rearrange('b h n d -> b n (h d)')
|
||||
|
||||
dim_inner = dim_head * heads
|
||||
self.to_q = LinearNoBias(dim, dim_inner)
|
||||
self.to_kv = LinearNoBias(dim, dim_inner * 2)
|
||||
self.to_out = LinearNoBias(dim_inner, dim)
|
||||
dim_q_inner = dim_head * query_heads
|
||||
dim_kv_inner = dim_head * heads
|
||||
|
||||
self.to_q = LinearNoBias(dim, dim_q_inner)
|
||||
self.to_kv = LinearNoBias(dim, dim_kv_inner * 2)
|
||||
self.to_out = LinearNoBias(dim_kv_inner, dim)
|
||||
|
||||
# masking related
|
||||
|
||||
@ -342,7 +406,7 @@ class Attention(Module):
|
||||
|
||||
# stability related
|
||||
|
||||
self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads)
|
||||
self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads)
|
||||
self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads)
|
||||
|
||||
self.softclamp_value = softclamp_value
|
||||
@ -376,38 +440,15 @@ class Attention(Module):
|
||||
k = cat((ck, k), dim = -2)
|
||||
v = cat((cv, v), dim = -2)
|
||||
|
||||
# similarity
|
||||
# attention
|
||||
|
||||
sim = einsum(q, k, 'b h i d, b h j d -> b h i j')
|
||||
|
||||
# softclamping a la gemma 3
|
||||
|
||||
if exists(self.softclamp_value):
|
||||
sim = softclamp(sim, self.softclamp_value)
|
||||
|
||||
# scale and attention
|
||||
|
||||
sim = sim * self.scale
|
||||
|
||||
# masking
|
||||
|
||||
mask_value = -torch.finfo(sim.dtype).max
|
||||
|
||||
if exists(mask):
|
||||
sim = sim.masked_fill(~mask, mask_value)
|
||||
|
||||
if self.causal:
|
||||
i, j = sim.shape[-2:]
|
||||
causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)
|
||||
sim = sim.masked_fill(causal_mask, mask_value)
|
||||
|
||||
# attend
|
||||
|
||||
attn = sim.softmax(dim = -1)
|
||||
|
||||
# aggregate
|
||||
|
||||
out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
|
||||
out = naive_attend(
|
||||
q, k, v,
|
||||
softclamp_value = self.softclamp_value,
|
||||
scale = self.scale,
|
||||
causal = self.causal,
|
||||
mask = mask
|
||||
)
|
||||
|
||||
# merge heads
|
||||
|
||||
|
||||
@ -3,8 +3,10 @@ param = pytest.mark.parametrize
|
||||
import torch
|
||||
|
||||
@param('pred_orig_latent', (False, True))
|
||||
@param('gqa', (False, True))
|
||||
def test_e2e(
|
||||
pred_orig_latent
|
||||
pred_orig_latent,
|
||||
gqa
|
||||
):
|
||||
from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel
|
||||
|
||||
@ -17,7 +19,19 @@ def test_e2e(
|
||||
latents = tokenizer(x, return_latents = True)
|
||||
assert latents.shape[-1] == 32
|
||||
|
||||
dynamics = DynamicsModel(512, dim_latent = 32, num_signal_levels = 500, num_step_sizes = 32, pred_orig_latent = pred_orig_latent)
|
||||
query_heads, heads = (16, 4) if gqa else (8, 8)
|
||||
|
||||
dynamics = DynamicsModel(
|
||||
512,
|
||||
dim_latent = 32,
|
||||
num_signal_levels = 500,
|
||||
num_step_sizes = 32,
|
||||
pred_orig_latent = pred_orig_latent,
|
||||
attn_kwargs = dict(
|
||||
heads = heads,
|
||||
query_heads = query_heads
|
||||
)
|
||||
)
|
||||
|
||||
signal_levels = torch.randint(0, 500, (2, 4))
|
||||
step_sizes = torch.randint(0, 32, (2, 4))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user