ready the block causal mask

This commit is contained in:
lucidrains 2025-10-01 09:45:54 -07:00
parent 2e92c0121a
commit c979883f21

View File

@ -14,6 +14,18 @@ import einx
from einops import einsum, rearrange, repeat, reduce
from einops.layers.torch import Rearrange
# flex attention - but will make sure it works if it is not available
# may also end up crafting own custom flash attention kernel for this work
flex_attention = None
try:
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
if torch.cuda.is_available():
flex_attention = torch.compile(flex_attention)
except ImportError:
pass
# constants
LinearNoBias = partial(Linear, bias = False)
@ -52,6 +64,27 @@ class MultiHeadRMSNorm(Module):
scale = (self.gamma + 1.) * self.scale
return einx.multiply('... h n d, h d', normed, scale)
# masking related
# block causal mask (space fully attends within each block, while time is causal)
def flex_block_causal_mask(seq_len, block_size):
def create_mask(_, __, qi, ki):
q_block_index = qi // block_size
k_block_index = ki // block_size
return q_block_index >= k_block_index
block_mask = create_block_mask(create_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)
return block_mask
def nonflex_block_causal_mask(seq_len, block_size, device = None):
blocks = ceil(seq_len / block_size)
causal_mask = torch.ones((blocks, blocks), device = device, dtype = torch.bool).tril()
block_causal_mask = repeat(causal_mask, 'i j -> (i block_size1) (j block_size2)', block_size1 = block_size, block_size1 = block_size)
return block_causal_mask[:seq_len, :seq_len]
# attention
class Attention(Module):