grouped query attention is necessary
commit 8373cb13ec
parent 58a6964dd9
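For context: the commit threads a separate query_heads count through the Attention module so that the key/value heads can be fewer than the query heads (grouped query attention). GQA lets several query heads share one k/v head, shrinking the k/v projections and any k/v cache by the group factor while keeping query expressivity. A minimal, self-contained sketch of the head grouping the commit relies on (toy shapes, illustrative names, not the repo's code):

    import torch
    from einops import rearrange, einsum

    query_heads, kv_heads = 16, 4            # every 4 query heads share one k/v head
    groups = query_heads // kv_heads

    q = torch.randn(2, query_heads, 128, 64)
    k = torch.randn(2, kv_heads, 128, 64)
    v = torch.randn(2, kv_heads, 128, 64)

    # fold query heads into (kv head, group), so each group attends against its shared k/v head
    q = rearrange(q, 'b (h g) n d -> b h g n d', g = groups)

    sim = einsum(q, k, 'b h g i d, b h j d -> b h g i j') * (64 ** -0.5)
    attn = sim.softmax(dim = -1)
    out = einsum(attn, v, 'b h g i j, b h j d -> b h g i d')

    out = rearrange(out, 'b h g n d -> b (h g) n d')    # back to 16 query heads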
@@ -265,7 +265,7 @@ def flex_block_mask(
     # assume special tokens (either latent or agent tokens) are placed at the right hand side
     # so [modality] [latents | agent]

-    def create_mask(_, __, qi, ki):
+    def create_mask(b, __, qi, ki):
         q_block_index = qi // block_size
         k_block_index = ki // block_size

@@ -274,7 +274,7 @@ def flex_block_mask(
         q_is_special = (qi % block_size) >= special_token_index_start
         k_is_special = (ki % block_size) >= special_token_index_start

-        mask = True
+        mask = b >= -1 # makeshift True tensor

         if is_causal:
             mask &= q_block_index >= k_block_index
@@ -285,7 +285,7 @@ def flex_block_mask(
         if prevent_special_to_modality:
             mask &= ~(~q_is_special & k_is_special)

-        return causal_mask
+        return mask

     block_mask = create_block_mask(create_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)
     return block_mask
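The two changes above follow from how create_block_mask evaluates the mask function: a mask_mod is called over index tensors and is expected to return a boolean tensor, so with every flag turned off `mask = True` would have stayed a plain Python bool, hence the `b >= -1` always-True tensor built from the batch index; the old `return causal_mask` also returned the wrong variable, since the accumulated conditions live in `mask`. A standalone sketch of the same pattern (function name and the cpu device are assumptions, not the repo's code):

    import torch
    from torch.nn.attention.flex_attention import create_block_mask

    block_size, seq_len = 4, 16

    def block_causal_mask_mod(b, h, qi, ki):
        mask = b >= -1                                       # always-True tensor, since mask_mod must return tensors
        mask &= (qi // block_size) >= (ki // block_size)     # block causal
        return mask

    block_mask = create_block_mask(block_causal_mask_mod, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, device = 'cpu')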
@@ -312,6 +312,60 @@ def nonflex_block_causal_mask(seq_len, block_size, device = None):

     return block_causal_mask[:seq_len, :seq_len]

+# attend functions
+
+def naive_attend(
+    q, k, v,
+    softclamp_value = 50.,
+    scale = None,
+    causal = False,
+    mask = None
+):
+
+    if not exists(scale):
+        scale = q.shape[-1] ** -0.5
+
+    # grouped query attention - fold query heads into groups that share one k/v head
+
+    groups = q.shape[1] // k.shape[1]
+
+    q = rearrange(q, 'b (h g) ... -> b h g ...', g = groups)
+
+    # similarity
+
+    sim = einsum(q, k, 'b h g i d, b h j d -> b h g i j')
+
+    # softclamping a la gemma 3
+
+    if exists(softclamp_value):
+        sim = softclamp(sim, softclamp_value)
+
+    # scale and attention
+
+    sim = sim * scale
+
+    # masking
+
+    mask_value = -torch.finfo(sim.dtype).max
+
+    if exists(mask):
+        sim = sim.masked_fill(~mask, mask_value)
+
+    if causal:
+        i, j = sim.shape[-2:]
+        causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)
+        sim = sim.masked_fill(causal_mask, mask_value)
+
+    # attend
+
+    attn = sim.softmax(dim = -1)
+
+    # aggregate, then merge the grouped query heads back
+
+    out = einsum(attn, v, 'b h g i j, b h j d -> b h g i d')
+
+    return rearrange(out, 'b h g ... -> b (h g) ...')
+
 # attention

 class Attention(Module):
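A hedged usage sketch of the new naive_attend under GQA shapes. The exists/softclamp stand-ins are assumptions about the repo's small helpers (softclamp taken as the gemma-style tanh soft cap) and are included only so the call is self-contained:

    import torch

    def exists(v):                       # stand-in for the repo's helper (assumption)
        return v is not None

    def softclamp(t, value):             # gemma-style soft cap (assumed to match the repo's)
        return (t / value).tanh() * value

    q = torch.randn(2, 16, 128, 64)      # 16 query heads
    k = torch.randn(2, 4, 128, 64)       #  4 shared k/v heads -> groups of 4
    v = torch.randn(2, 4, 128, 64)

    out = naive_attend(q, k, v, causal = True)
    print(out.shape)                     # (2, 16, 128, 64) once the grouped heads are merged back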
@@ -319,6 +373,7 @@ class Attention(Module):
         self,
         dim,
         dim_head = 64,
+        query_heads = None,
         heads = 8,
         softclamp_value = 50.,
         pre_rmsnorm = True,
@@ -327,14 +382,23 @@ class Attention(Module):
         super().__init__()
         self.norm = RMSNorm(dim) if pre_rmsnorm else Identity()

+        # setup grouped query attention
+
+        query_heads = default(query_heads, heads)
+        assert query_heads >= heads and divisible_by(query_heads, heads)
+
+        # scaling, splitting and merging of heads
+
         self.scale = dim_head ** -0.5
-        self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
+        self.split_heads = Rearrange('b n (h d) -> b h n d', d = dim_head)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')

-        dim_inner = dim_head * heads
-        self.to_q = LinearNoBias(dim, dim_inner)
-        self.to_kv = LinearNoBias(dim, dim_inner * 2)
-        self.to_out = LinearNoBias(dim_inner, dim)
+        dim_q_inner = dim_head * query_heads
+        dim_kv_inner = dim_head * heads
+
+        self.to_q = LinearNoBias(dim, dim_q_inner)
+        self.to_kv = LinearNoBias(dim, dim_kv_inner * 2)
+        self.to_out = LinearNoBias(dim_q_inner, dim)

         # masking related
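The split_heads change is what lets a single Rearrange serve both projections: specifying d = dim_head instead of h = heads means the head count is inferred from each tensor, so queries (query_heads heads) and keys/values (heads heads) can share it. A small illustration with assumed toy shapes:

    import torch
    from einops.layers.torch import Rearrange

    dim_head = 64
    split_heads = Rearrange('b n (h d) -> b h n d', d = dim_head)

    q = split_heads(torch.randn(2, 128, 16 * dim_head))   # -> (2, 16, 128, 64): 16 query heads
    k = split_heads(torch.randn(2, 128, 4 * dim_head))    # -> (2,  4, 128, 64):  4 k/v heads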
@@ -342,7 +406,7 @@ class Attention(Module):

         # stability related

-        self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads)
+        self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = query_heads)
         self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads)

         self.softclamp_value = softclamp_value
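The qk-norm has to be constructed with the number of heads it will see, which is why only the query-side norm switches to query_heads. A minimal sketch of what such a per-head RMSNorm typically looks like (an assumption about MultiHeadRMSNorm, not pulled from the repo):

    import torch
    from torch import nn
    import torch.nn.functional as F

    class PerHeadRMSNorm(nn.Module):
        def __init__(self, dim_head, heads):
            super().__init__()
            self.scale = dim_head ** 0.5
            self.gamma = nn.Parameter(torch.ones(heads, 1, dim_head))   # one learned gain per head

        def forward(self, x):            # x: (batch, heads, seq, dim_head)
            return F.normalize(x, dim = -1) * self.gamma * self.scale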
@@ -376,38 +440,15 @@ class Attention(Module):
             k = cat((ck, k), dim = -2)
             v = cat((cv, v), dim = -2)

-        # similarity
-
-        sim = einsum(q, k, 'b h i d, b h j d -> b h i j')
-
-        # softclamping a la gemma 3
-
-        if exists(self.softclamp_value):
-            sim = softclamp(sim, self.softclamp_value)
-
-        # scale and attention
-
-        sim = sim * self.scale
-
-        # masking
-
-        mask_value = -torch.finfo(sim.dtype).max
-
-        if exists(mask):
-            sim = sim.masked_fill(~mask, mask_value)
-
-        if self.causal:
-            i, j = sim.shape[-2:]
-            causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)
-            sim = sim.masked_fill(causal_mask, mask_value)
-
-        # attend
-
-        attn = sim.softmax(dim = -1)
-
-        # aggregate
-
-        out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
+        # attention
+
+        out = naive_attend(
+            q, k, v,
+            softclamp_value = self.softclamp_value,
+            scale = self.scale,
+            causal = self.causal,
+            mask = mask
+        )

         # merge heads

@@ -3,8 +3,10 @@ param = pytest.mark.parametrize
 import torch

 @param('pred_orig_latent', (False, True))
+@param('gqa', (False, True))
 def test_e2e(
-    pred_orig_latent
+    pred_orig_latent,
+    gqa
 ):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel

@@ -17,7 +19,19 @@ def test_e2e(
     latents = tokenizer(x, return_latents = True)
     assert latents.shape[-1] == 32

-    dynamics = DynamicsModel(512, dim_latent = 32, num_signal_levels = 500, num_step_sizes = 32, pred_orig_latent = pred_orig_latent)
+    query_heads, heads = (16, 4) if gqa else (8, 8)
+
+    dynamics = DynamicsModel(
+        512,
+        dim_latent = 32,
+        num_signal_levels = 500,
+        num_step_sizes = 32,
+        pred_orig_latent = pred_orig_latent,
+        attn_kwargs = dict(
+            heads = heads,
+            query_heads = query_heads
+        )
+    )

     signal_levels = torch.randint(0, 500, (2, 4))
     step_sizes = torch.randint(0, 32, (2, 4))