softclamping in flex
This commit is contained in:
parent 8e7a35b89c
commit 67519a451d
@@ -190,6 +190,20 @@ def flex_block_causal_mask(
     block_mask = create_block_mask(create_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)
     return block_mask
 
+# for softclamping with flex attention
+
+def softclamp_score_mod(value):
+
+    def inner(attn_logits, b, h, qi, ki):
+        attn_logits = attn_logits / value
+        attn_logits = torch.tanh(attn_logits)
+        attn_logits = attn_logits * value
+        return attn_logits
+
+    return inner
+
+# todo - reuse the inner function from flex attn above with broadcasting
+
 def nonflex_block_causal_mask(seq_len, block_size, device = None):
     blocks = ceil(seq_len / block_size)
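
For context, a minimal sketch of how the score_mod added in this commit could be wired into PyTorch's flex attention. Only softclamp_score_mod itself comes from the diff above; the clamp value (50.0), the tensor shapes, and the variable names are illustrative assumptions, not part of the commit.

# minimal sketch, assuming PyTorch >= 2.5 with flex attention available
# (flex_attention may require a CUDA device on older versions)
import torch
from torch.nn.attention.flex_attention import flex_attention

def softclamp_score_mod(value):
    # same tanh soft-capping as in the commit: each attention logit
    # is squashed into the open interval (-value, value)
    def inner(attn_logits, b, h, qi, ki):
        attn_logits = attn_logits / value
        attn_logits = torch.tanh(attn_logits)
        attn_logits = attn_logits * value
        return attn_logits
    return inner

# illustrative shapes: (batch, heads, seq_len, head_dim)
q = torch.randn(1, 8, 128, 64)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)

# 50.0 is a hypothetical soft-cap value, not taken from the commit
out = flex_attention(q, k, v, score_mod = softclamp_score_mod(50.0))

Passing the block mask built by flex_block_causal_mask through flex_attention's block_mask keyword would combine block-causal masking with the soft clamp, but that wiring is outside this hunk.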