Add block-wise causal attention (causal_block_size) to naive_attend,
block_mask_causal and get_attend_fn, use it in VideoTokenizer, and
parametrize test_attend_factory over causal_block_size.

Review notes (text before the first "diff --git" header is not part of the patch):

  1. This patch arrived with its line breaks collapsed. Line breaks, blank
     context lines and indentation below were reconstructed from the hunk
     line counts assuming 4-space indentation. If a context line fails to
     match, apply with `git apply --ignore-whitespace`.
  2. VideoTokenizer now forwards **attend_kwargs to get_attend_fn. Before,
     the dict was built but never passed, so get_attend_fn fell back to its
     defaults (causal = False, causal_block_size = 1, softclamp_value = 50.).
  3. naive_attend reuses is_blocked_causal for the mask expansion check
     instead of re-testing causal_block_size > 1 (same behaviour).
  4. The "index" line still carries the blob hashes of the original patch.

diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 0ab6a76..725f167 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import math
+from math import ceil
 from collections import namedtuple
 from functools import partial
 
@@ -423,6 +424,7 @@ def naive_attend(
     softclamp_value = None,
     scale = None,
     causal = False,
+    causal_block_size = 1,
     mask = None
 ):
 
@@ -456,8 +458,19 @@ def naive_attend(
         sim = sim.masked_fill(~mask, mask_value)
 
     if causal:
+        is_blocked_causal = causal_block_size > 1
         i, j = sim.shape[-2:]
+
+        if is_blocked_causal:
+            i = ceil(i / causal_block_size)
+            j = ceil(j / causal_block_size)
+
         causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)
+
+        if is_blocked_causal:
+            causal_mask = repeat(causal_mask, 'i j -> (i b1) (j b2)', b1 = causal_block_size, b2 = causal_block_size)
+            causal_mask = causal_mask[:sim.shape[-2], :sim.shape[-1]]
+
         sim = sim.masked_fill(causal_mask, mask_value)
 
     # attend
@@ -474,8 +487,13 @@ def naive_attend(
 
 # flex attention related and factory function for attend depending on whether on cuda + flex attention available
 
-def block_mask_causal(b, h, q, k):
-    return q >= k
+def block_mask_causal(block_size):
+
+    def inner(b, h, q, k):
+        bq = q // block_size
+        bk = k // block_size
+        return bq >= bk
+    return inner
 
 def agent_token_mask(q, k, seq_len, num_tokens):
     is_special_start_index = seq_len - num_tokens
@@ -519,6 +537,7 @@ def get_attend_fn(
     seq_len,
     k_seq_len,
     causal = False,
+    causal_block_size = 1,
     softclamp_value = 50.,
     num_agent_tokens = 0,
     device = None
@@ -526,7 +545,7 @@ def get_attend_fn(
     if use_flex:
         # flex pathway
 
-        block_mask_fn = block_mask_causal if causal else block_mask_noop
+        block_mask_fn = block_mask_causal(causal_block_size) if causal else block_mask_noop
 
         if num_agent_tokens > 0:
             agent_block_mask = block_mask_special_tokens_right(k_seq_len, num_agent_tokens)
@@ -546,7 +565,7 @@ def get_attend_fn(
 
             mask = agent_token_mask(q_seq, k_seq, k_seq_len, num_agent_tokens)
 
-        attend_fn = partial(naive_attend, causal = causal, mask = mask, softclamp_value = softclamp_value)
+        attend_fn = partial(naive_attend, causal = causal, causal_block_size = causal_block_size, mask = mask, softclamp_value = softclamp_value)
 
     return attend_fn
 
@@ -858,15 +877,25 @@ class VideoTokenizer(Module):
 
         tokens, packed_latent_shape = pack((tokens, latents), 'b t * d')
 
+        space_seq_len = tokens.shape[-2]
+
         # pack time
 
         tokens, inverse_pack_time = pack_one(tokens, 'b * d')
 
+        seq_len = tokens.shape[1]
+
         # attend hyper parameters
 
-        attend_kwargs = dict(softclamp_value = self.attn_softclamp_value)
+        attend_kwargs = dict(
+            causal = True,
+            causal_block_size = space_seq_len,
+            softclamp_value = self.attn_softclamp_value
+        )
 
-        attend_fn = partial(naive_attend, **attend_kwargs)
+        use_flex = tokens.is_cuda and exists(flex_attention)
+
+        attend_fn = get_attend_fn(use_flex, seq_len, seq_len, **attend_kwargs)
 
         # encoder
 
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index 0a7a2e3..be28f06 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -63,10 +63,12 @@ def test_symexp_two_hot():
 @param('causal', (False, True))
 @param('softclamp_value', (50., None))
 @param('num_agent_tokens', (0, 1))
+@param('causal_block_size', (1, 8))
 def test_attend_factory(
     causal,
     softclamp_value,
-    num_agent_tokens
+    num_agent_tokens,
+    causal_block_size
 ):
 
     from dreamer4.dreamer4 import get_attend_fn
@@ -75,7 +77,15 @@ def test_attend_factory(
     k = torch.randn(1, 4, 1024, 512).cuda()
     v = torch.randn(1, 4, 1024, 512).cuda()
 
-    attend_kwargs = dict(seq_len = 1024, k_seq_len = 1024, causal = causal, softclamp_value = softclamp_value, device = q.device, num_agent_tokens = num_agent_tokens)
+    attend_kwargs = dict(
+        seq_len = 1024,
+        k_seq_len = 1024,
+        causal = causal,
+        causal_block_size = causal_block_size,
+        softclamp_value = softclamp_value,
+        device = q.device,
+        num_agent_tokens = num_agent_tokens
+    )
 
     attend = get_attend_fn(True, **attend_kwargs)
     flex_out = attend(q, k, v)