ready the block causal mask
This commit is contained in:
parent
2e92c0121a
commit
c979883f21
@ -14,6 +14,18 @@ import einx
|
||||
from einops import einsum, rearrange, repeat, reduce
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# flex attention - but will make sure it works if it is not available
|
||||
# may also end up crafting own custom flash attention kernel for this work
|
||||
|
||||
flex_attention = None
|
||||
|
||||
try:
|
||||
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
|
||||
if torch.cuda.is_available():
|
||||
flex_attention = torch.compile(flex_attention)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# constants
|
||||
|
||||
LinearNoBias = partial(Linear, bias = False)
|
||||
@ -52,6 +64,27 @@ class MultiHeadRMSNorm(Module):
|
||||
scale = (self.gamma + 1.) * self.scale
|
||||
return einx.multiply('... h n d, h d', normed, scale)
|
||||
|
||||
# masking related
|
||||
# block causal mask (space fully attends within each block, while time is causal)
|
||||
|
||||
def flex_block_causal_mask(seq_len, block_size):
|
||||
|
||||
def create_mask(_, __, qi, ki):
|
||||
q_block_index = qi // block_size
|
||||
k_block_index = ki // block_size
|
||||
return q_block_index >= k_block_index
|
||||
|
||||
block_mask = create_block_mask(create_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)
|
||||
return block_mask
|
||||
|
||||
def nonflex_block_causal_mask(seq_len, block_size, device = None):
|
||||
blocks = ceil(seq_len / block_size)
|
||||
|
||||
causal_mask = torch.ones((blocks, blocks), device = device, dtype = torch.bool).tril()
|
||||
block_causal_mask = repeat(causal_mask, 'i j -> (i block_size1) (j block_size2)', block_size1 = block_size, block_size1 = block_size)
|
||||
|
||||
return block_causal_mask[:seq_len, :seq_len]
|
||||
|
||||
# attention
|
||||
|
||||
class Attention(Module):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user