diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index b10ec76..a281e0f 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -14,6 +14,18 @@ import einx
 from einops import einsum, rearrange, repeat, reduce
 from einops.layers.torch import Rearrange
 
+# flex attention - but will make sure the code still works if it is not available
+# may also end up crafting own custom flash attention kernel for this work
+
+flex_attention = None
+
+try:
+    from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+    if torch.cuda.is_available():
+        flex_attention = torch.compile(flex_attention)
+except ImportError:
+    pass
+
 # constants
 
 LinearNoBias = partial(Linear, bias = False)
@@ -52,6 +64,27 @@ class MultiHeadRMSNorm(Module):
         scale = (self.gamma + 1.) * self.scale
         return einx.multiply('... h n d, h d', normed, scale)
 
+# masking related
+# block causal mask (space fully attends within each block, while time is causal)
+
+def flex_block_causal_mask(seq_len, block_size):
+
+    def create_mask(_, __, qi, ki):
+        q_block_index = qi // block_size
+        k_block_index = ki // block_size
+        return q_block_index >= k_block_index
+
+    block_mask = create_block_mask(create_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)
+    return block_mask
+
+def nonflex_block_causal_mask(seq_len, block_size, device = None):
+    blocks = ceil(seq_len / block_size)
+
+    causal_mask = torch.ones((blocks, blocks), device = device, dtype = torch.bool).tril()
+    block_causal_mask = repeat(causal_mask, 'i j -> (i block_size1) (j block_size2)', block_size1 = block_size, block_size2 = block_size)
+
+    return block_causal_mask[:seq_len, :seq_len]
+
 # attention
 
 class Attention(Module):
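
A minimal usage sketch of the non-flex fallback, assuming the module is importable as dreamer4.dreamer4 and that the boolean mask is passed to torch.nn.functional.scaled_dot_product_attention as attn_mask (neither is shown in this diff):

    import torch
    import torch.nn.functional as F
    from dreamer4.dreamer4 import nonflex_block_causal_mask  # assumed import path

    # 3 time steps of 4 spatial tokens each -> sequence length 12
    seq_len, block_size = 12, 4
    mask = nonflex_block_causal_mask(seq_len, block_size)  # (12, 12) bool, True = may attend

    # each token attends to all tokens in its own block and in earlier blocks (block causal)
    q = k = v = torch.randn(1, 8, seq_len, 64)
    out = F.scaled_dot_product_attention(q, k, v, attn_mask = mask)

When flex attention is available, flex_block_causal_mask builds the equivalent block mask once via create_block_mask and reuses it across calls, avoiding the dense (seq_len, seq_len) boolean tensor.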