diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index a281e0f..eae8877 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -81,7 +81,7 @@ def nonflex_block_causal_mask(seq_len, block_size, device = None): blocks = ceil(seq_len / block_size) causal_mask = torch.ones((blocks, blocks), device = device, dtype = torch.bool).tril() - block_causal_mask = repeat(causal_mask, 'i j -> (i block_size1) (j block_size2)', block_size1 = block_size, block_size1 = block_size) + block_causal_mask = repeat(causal_mask, 'i j -> (i block_size1) (j block_size2)', block_size1 = block_size, block_size2 = block_size) return block_causal_mask[:seq_len, :seq_len]