softclamping in flex
parent 8e7a35b89c
commit 67519a451d
@@ -190,6 +190,20 @@ def flex_block_causal_mask(
    block_mask = create_block_mask(create_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)

    return block_mask
# for softclamping with flex attention

def softclamp_score_mod(value):

    def inner(attn_logits, b, h, qi, ki):
        attn_logits = attn_logits / value
        attn_logits = torch.tanh(attn_logits)
        attn_logits = attn_logits * value
        return attn_logits

    return inner
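A minimal usage sketch, not part of the commit: assuming PyTorch's flex_attention API and hypothetical tensor shapes plus a clamp value of 50.0, the returned inner function is passed as the score_mod so every attention logit is tanh-clamped into (-50, 50) before softmax.

import torch
from torch.nn.attention.flex_attention import flex_attention

# hypothetical shapes and clamp value, for illustration only
batch, heads, seq_len, dim = 2, 8, 1024, 64

q = torch.randn(batch, heads, seq_len, dim)
k = torch.randn(batch, heads, seq_len, dim)
v = torch.randn(batch, heads, seq_len, dim)

# softclamp each attention logit; the block mask returned by
# flex_block_causal_mask above can additionally be passed via block_mask = ...
out = flex_attention(q, k, v, score_mod = softclamp_score_mod(50.))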
# todo - reuse the inner function from flex attn above with broadcasting

def nonflex_block_causal_mask(seq_len, block_size, device = None):
    blocks = ceil(seq_len / block_size)
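The hunk ends here. As a hedged sketch of the broadcasting approach the todo points at (the name and return shape are assumptions, not the repository's code), a dense block-causal boolean mask can be built without looping over blocks:

import torch

def nonflex_block_causal_mask_sketch(seq_len, block_size, device = None):
    # map each position to its block index, then let a query attend to a key
    # iff the key's block comes at or before the query's block
    block_ids = torch.arange(seq_len, device = device) // block_size   # (seq_len,)
    return block_ids[:, None] >= block_ids[None, :]                    # (seq_len, seq_len) bool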