ready the block causal mask
parent 2e92c0121a
commit c979883f21
@@ -14,6 +14,18 @@ import einx
 from einops import einsum, rearrange, repeat, reduce
 from einops.layers.torch import Rearrange
 
+# flex attention - but will make sure it works if it is not available
+# may also end up crafting own custom flash attention kernel for this work
+
+flex_attention = None
+
+try:
+    from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+    if torch.cuda.is_available():
+        flex_attention = torch.compile(flex_attention)
+
+except ImportError:
+    pass
+
 # constants
 
 LinearNoBias = partial(Linear, bias = False)
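The guard above leaves flex_attention as None when the import fails, so downstream code can branch on it. Below is a minimal sketch (not part of the commit) of what that branch might look like; the attend helper and its argument names are hypothetical, while flex_attention(q, k, v, block_mask = ...) and F.scaled_dot_product_attention are the actual PyTorch entry points.

import torch
import torch.nn.functional as F

# reproduce the availability guard so this sketch runs on its own
flex_attention = None

try:
    from torch.nn.attention.flex_attention import flex_attention
except ImportError:
    pass

def attend(q, k, v, block_mask = None, attn_mask = None):
    # hypothetical helper: use the flex kernel when it imported and a BlockMask
    # was built, otherwise fall back to dense scaled dot product attention
    if flex_attention is not None and block_mask is not None:
        return flex_attention(q, k, v, block_mask = block_mask)

    return F.scaled_dot_product_attention(q, k, v, attn_mask = attn_mask)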
@@ -52,6 +64,27 @@ class MultiHeadRMSNorm(Module):
         scale = (self.gamma + 1.) * self.scale
         return einx.multiply('... h n d, h d', normed, scale)
 
+# masking related
+# block causal mask (space fully attends within each block, while time is causal)
+
+def flex_block_causal_mask(seq_len, block_size):
+
+    def create_mask(_, __, qi, ki):
+        q_block_index = qi // block_size
+        k_block_index = ki // block_size
+        return q_block_index >= k_block_index
+
+    block_mask = create_block_mask(create_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)
+    return block_mask
+
+def nonflex_block_causal_mask(seq_len, block_size, device = None):
+    blocks = ceil(seq_len / block_size)
+
+    causal_mask = torch.ones((blocks, blocks), device = device, dtype = torch.bool).tril()
+    block_causal_mask = repeat(causal_mask, 'i j -> (i block_size1) (j block_size2)', block_size1 = block_size, block_size2 = block_size)
+
+    return block_causal_mask[:seq_len, :seq_len]
+
 # attention
 
 class Attention(Module):
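To make the mask pattern concrete, here is a standalone sanity check (not part of the commit) of the non-flex path; the function is copied from the hunk above with the duplicated block_size1 keyword corrected to block_size2, and seq_len = 6, block_size = 3 are arbitrary illustration values.

from math import ceil

import torch
from einops import repeat

def nonflex_block_causal_mask(seq_len, block_size, device = None):
    blocks = ceil(seq_len / block_size)
    causal_mask = torch.ones((blocks, blocks), device = device, dtype = torch.bool).tril()
    block_causal_mask = repeat(causal_mask, 'i j -> (i block_size1) (j block_size2)', block_size1 = block_size, block_size2 = block_size)
    return block_causal_mask[:seq_len, :seq_len]

mask = nonflex_block_causal_mask(6, 3)
print(mask.int())

# expected pattern - positions 0-2 (first block) attend only within their block,
# positions 3-5 (second block) attend to both blocks:
#
# 1 1 1 0 0 0
# 1 1 1 0 0 0
# 1 1 1 0 0 0
# 1 1 1 1 1 1
# 1 1 1 1 1 1
# 1 1 1 1 1 1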