This commit is contained in:
lucidrains 2025-10-01 09:49:04 -07:00
parent c979883f21
commit ceb1af263e

View File

@ -81,7 +81,7 @@ def nonflex_block_causal_mask(seq_len, block_size, device = None):
blocks = ceil(seq_len / block_size)
causal_mask = torch.ones((blocks, blocks), device = device, dtype = torch.bool).tril()
block_causal_mask = repeat(causal_mask, 'i j -> (i block_size1) (j block_size2)', block_size1 = block_size, block_size1 = block_size)
block_causal_mask = repeat(causal_mask, 'i j -> (i block_size1) (j block_size2)', block_size1 = block_size, block_size2 = block_size)
return block_causal_mask[:seq_len, :seq_len]