softclamping in flex

lucidrains 2025-10-01 12:19:41 -07:00
parent 8e7a35b89c
commit 67519a451d


@@ -190,6 +190,20 @@ def flex_block_causal_mask(
    block_mask = create_block_mask(create_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)
    return block_mask
# for softclamping with flex attention

def softclamp_score_mod(value):
    # returns a flex attention score_mod that soft caps attention logits to (-value, value) via tanh
    def inner(attn_logits, b, h, qi, ki):
        attn_logits = attn_logits / value
        attn_logits = torch.tanh(attn_logits)
        attn_logits = attn_logits * value
        return attn_logits

    return inner
# todo - reuse the inner function from flex attn above with broadcasting
def nonflex_block_causal_mask(seq_len, block_size, device = None):
    blocks = ceil(seq_len / block_size)
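
The score_mod above computes softclamp(x) = value * tanh(x / value), which smoothly bounds attention logits to (-value, value) while remaining differentiable. A minimal usage sketch, not part of this commit: it assumes PyTorch >= 2.5's torch.nn.attention.flex_attention API, and the clamp value 50. is hypothetical, chosen only for illustration.

    # sketch: wiring the softclamp score_mod into flex attention (assumes torch >= 2.5)
    import torch
    from torch.nn.attention.flex_attention import flex_attention

    q = k = v = torch.randn(1, 8, 256, 64)  # hypothetical (batch, heads, seq_len, head_dim)

    # flex_attention applies the score_mod to each attention logit before softmax
    out = flex_attention(q, k, v, score_mod = softclamp_score_mod(50.))
    assert out.shape == (1, 8, 256, 64)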