# for softclamping with flex attention

def softclamp_score_mod(value):
    """Build a score-mod callable that tanh-softclamps attention logits.

    The returned function computes ``tanh(attn_logits / value) * value``, which
    keeps every logit within ``[-abs(value), abs(value)]`` while staying
    approximately linear for logits much smaller than ``value``.

    Args:
        value: The clamp magnitude. Must be non-zero. A negative value gives
            the same result as its absolute value because tanh is odd.

    Returns:
        A function ``inner(attn_logits, b, h, qi, ki)`` that returns the
        softclamped logits. The four index arguments are accepted but unused;
        presumably they are the batch / head / query / key indices that flex
        attention passes to a score mod - confirm against the caller.

    Raises:
        ValueError: If ``value`` is a numeric zero, which would otherwise turn
            the logits into zeros / NaNs through a division by zero.
    """

    # only validate plain python numbers - comparing an arbitrary tensor here
    # could be ambiguous, and those inputs keep their original behavior
    if isinstance(value, (int, float)) and value == 0:
        raise ValueError('softclamp value must be non-zero')

    def inner(attn_logits, b, h, qi, ki):
        # scale down, squash into (-1, 1), then scale back to the clamp range
        return torch.tanh(attn_logits / value) * value

    return inner

# todo - reuse the inner function from flex attn above with broadcasting