take care of blocked causal attention in the video tokenizer; still need the special attention pattern for latents to and from, though

lucidrains 2025-10-04 12:03:50 -07:00
parent 6c994db341
commit 5c6be4d979
2 changed files with 47 additions and 8 deletions
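Blocked causal attention relaxes plain causal masking to the granularity of blocks: all tokens inside a block (here, one frame's worth of tokens) attend to each other bidirectionally, while blocks attend only to themselves and earlier blocks. A minimal standalone sketch of the resulting mask in plain PyTorch (illustration only, not repo code):

import torch

# block-causal mask: True where attention is allowed
def block_causal_mask(seq_len, block_size):
    blocks = torch.arange(seq_len) // block_size   # block index of each position
    return blocks[:, None] >= blocks[None, :]      # query block >= key block

print(block_causal_mask(6, 2).int())
# tensor([[1, 1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1, 1],
#         [1, 1, 1, 1, 1, 1]])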

dreamer4/dreamer4.py

@@ -1,6 +1,7 @@
 from __future__ import annotations

 import math
+from math import ceil
 from collections import namedtuple
 from functools import partial
@@ -423,6 +424,7 @@ def naive_attend(
     softclamp_value = None,
     scale = None,
     causal = False,
+    causal_block_size = 1,
     mask = None
 ):
@@ -456,8 +458,19 @@ def naive_attend(
         sim = sim.masked_fill(~mask, mask_value)

     if causal:
+        is_blocked_causal = causal_block_size > 1
+
         i, j = sim.shape[-2:]
+
+        if is_blocked_causal:
+            i = ceil(i / causal_block_size)
+            j = ceil(j / causal_block_size)
+
         causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)
+
+        if causal_block_size > 1:
+            causal_mask = repeat(causal_mask, 'i j -> (i b1) (j b2)', b1 = causal_block_size, b2 = causal_block_size)
+            causal_mask = causal_mask[:sim.shape[-2], :sim.shape[-1]]
+
         sim = sim.masked_fill(causal_mask, mask_value)

     # attend
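The naive path builds the causal mask at block resolution (ceil of the lengths divided by the block size), expands it back up with einops repeat, then trims to the true sequence lengths so a ragged final block is handled. A quick standalone check that this agrees with the position-level definition (illustrative snippet, not repo code):

import torch
from math import ceil
from einops import repeat

i = j = 6
block = 2

bi, bj = ceil(i / block), ceil(j / block)

# True above the block diagonal = entries to mask out
mask = torch.ones((bi, bj), dtype = torch.bool).triu(bj - bi + 1)
mask = repeat(mask, 'i j -> (i b1) (j b2)', b1 = block, b2 = block)[:i, :j]

# position-level reference: disallowed exactly where key block > query block
qb = torch.arange(i) // block
kb = torch.arange(j) // block
assert torch.equal(mask, kb[None, :] > qb[:, None])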
@@ -474,8 +487,13 @@ def naive_attend(
 # flex attention related and factory function for attend depending on whether on cuda + flex attention available

-def block_mask_causal(b, h, q, k):
-    return q >= k
+def block_mask_causal(block_size):
+    def inner(b, h, q, k):
+        bq = q // block_size
+        bk = k // block_size
+        return bq >= bk
+
+    return inner

 def agent_token_mask(q, k, seq_len, num_tokens):
     is_special_start_index = seq_len - num_tokens
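block_mask_causal is now a factory: it closes over the block size and returns the (b, h, q_idx, kv_idx) -> bool predicate that flex attention consumes as a mask_mod. A hedged sketch of how such a predicate is typically used, assuming a recent PyTorch with flex attention available (illustration only, not repo code):

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def block_mask_causal(block_size):
    def inner(b, h, q, k):
        return (q // block_size) >= (k // block_size)
    return inner

q = torch.randn(1, 4, 1024, 64, device = 'cuda')
k = torch.randn(1, 4, 1024, 64, device = 'cuda')
v = torch.randn(1, 4, 1024, 64, device = 'cuda')

# build the block mask once for these sequence lengths; it can be reused across calls
block_mask = create_block_mask(block_mask_causal(8), B = None, H = None, Q_LEN = 1024, KV_LEN = 1024)
out = flex_attention(q, k, v, block_mask = block_mask)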
@@ -519,6 +537,7 @@ def get_attend_fn(
     seq_len,
     k_seq_len,
     causal = False,
+    causal_block_size = 1,
     softclamp_value = 50.,
     num_agent_tokens = 0,
     device = None
@@ -526,7 +545,7 @@ def get_attend_fn(
     if use_flex:
         # flex pathway

-        block_mask_fn = block_mask_causal if causal else block_mask_noop
+        block_mask_fn = block_mask_causal(causal_block_size) if causal else block_mask_noop

         if num_agent_tokens > 0:
             agent_block_mask = block_mask_special_tokens_right(k_seq_len, num_agent_tokens)
@@ -546,7 +565,7 @@ def get_attend_fn(
             mask = agent_token_mask(q_seq, k_seq, k_seq_len, num_agent_tokens)

-        attend_fn = partial(naive_attend, causal = causal, mask = mask, softclamp_value = softclamp_value)
+        attend_fn = partial(naive_attend, causal = causal, causal_block_size = causal_block_size, mask = mask, softclamp_value = softclamp_value)

     return attend_fn
@@ -858,15 +877,25 @@ class VideoTokenizer(Module):
         tokens, packed_latent_shape = pack((tokens, latents), 'b t * d')

+        space_seq_len = tokens.shape[-2]
+
         # pack time

         tokens, inverse_pack_time = pack_one(tokens, 'b * d')

+        seq_len = tokens.shape[1]
+
         # attend hyper parameters

-        attend_kwargs = dict(softclamp_value = self.attn_softclamp_value)
-        attend_fn = partial(naive_attend, **attend_kwargs)
+        attend_kwargs = dict(
+            causal = True,
+            causal_block_size = space_seq_len,
+            softclamp_value = self.attn_softclamp_value
+        )
+
+        use_flex = tokens.is_cuda and exists(flex_attention)
+        attend_fn = get_attend_fn(use_flex, seq_len, seq_len, **attend_kwargs)

         # encoder
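In the tokenizer's forward, space (patch tokens plus the packed latents) and time are flattened into one sequence, so causal_block_size is set to the per-frame token count: attention is bidirectional within a frame and causal across frames. Illustrative arithmetic only (the numbers are made up):

# say 4 frames, each contributing 16 patch tokens + 2 latent tokens (hypothetical)
frames = 4
space_seq_len = 16 + 2
seq_len = frames * space_seq_len       # 72 positions after packing time

# position p lives in frame p // space_seq_len; blocked causality allows
# p to attend to p' iff p // space_seq_len >= p' // space_seq_len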

tests/test_dreamer4.py

@@ -63,10 +63,12 @@ def test_symexp_two_hot():
 @param('causal', (False, True))
 @param('softclamp_value', (50., None))
 @param('num_agent_tokens', (0, 1))
+@param('causal_block_size', (1, 8))
 def test_attend_factory(
     causal,
     softclamp_value,
-    num_agent_tokens
+    num_agent_tokens,
+    causal_block_size
 ):
from dreamer4.dreamer4 import get_attend_fn
@@ -75,7 +77,15 @@ def test_attend_factory(
     k = torch.randn(1, 4, 1024, 512).cuda()
     v = torch.randn(1, 4, 1024, 512).cuda()

-    attend_kwargs = dict(seq_len = 1024, k_seq_len = 1024, causal = causal, softclamp_value = softclamp_value, device = q.device, num_agent_tokens = num_agent_tokens)
+    attend_kwargs = dict(
+        seq_len = 1024,
+        k_seq_len = 1024,
+        causal = causal,
+        causal_block_size = causal_block_size,
+        softclamp_value = softclamp_value,
+        device = q.device,
+        num_agent_tokens = num_agent_tokens
+    )

     attend = get_attend_fn(True, **attend_kwargs)
     flex_out = attend(q, k, v)
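The test presumably goes on to compare the flex pathway against the naive one; a hedged sketch of that kind of check (assumed continuation, not the actual test body):

    # run the same inputs through the non-flex (naive) pathway and compare
    attend = get_attend_fn(False, **attend_kwargs)
    naive_out = attend(q, k, v)

    assert torch.allclose(flex_out, naive_out, atol = 1e-5)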