oops
This commit is contained in:
parent
c979883f21
commit
ceb1af263e
@ -81,7 +81,7 @@ def nonflex_block_causal_mask(seq_len, block_size, device = None):
|
||||
blocks = ceil(seq_len / block_size)
|
||||
|
||||
causal_mask = torch.ones((blocks, blocks), device = device, dtype = torch.bool).tril()
|
||||
block_causal_mask = repeat(causal_mask, 'i j -> (i block_size1) (j block_size2)', block_size1 = block_size, block_size1 = block_size)
|
||||
block_causal_mask = repeat(causal_mask, 'i j -> (i block_size1) (j block_size2)', block_size1 = block_size, block_size2 = block_size)
|
||||
|
||||
return block_causal_mask[:seq_len, :seq_len]
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user