take care of blocked causal attention in the video tokenizer; still need the special attention pattern for latents attending to and from the frame tokens, though
commit 5c6be4d979
parent 6c994db341
@@ -1,6 +1,7 @@
 from __future__ import annotations

 import math
+from math import ceil
 from collections import namedtuple
 from functools import partial

@@ -423,6 +424,7 @@ def naive_attend(
     softclamp_value = None,
     scale = None,
     causal = False,
+    causal_block_size = 1,
     mask = None
 ):

@@ -456,8 +458,19 @@ def naive_attend(
         sim = sim.masked_fill(~mask, mask_value)

     if causal:
+        is_blocked_causal = causal_block_size > 1
         i, j = sim.shape[-2:]

+        if is_blocked_causal:
+            i = ceil(i / causal_block_size)
+            j = ceil(j / causal_block_size)
+
         causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)

+        if causal_block_size > 1:
+            causal_mask = repeat(causal_mask, 'i j -> (i b1) (j b2)', b1 = causal_block_size, b2 = causal_block_size)
+            causal_mask = causal_mask[:sim.shape[-2], :sim.shape[-1]]
+
         sim = sim.masked_fill(causal_mask, mask_value)

     # attend
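For intuition, the following standalone sketch (sizes invented for illustration) reproduces the blocked causal mask built above, where True marks positions to be masked out:

import torch
from math import ceil
from einops import repeat

i = j = 6                # query / key lengths
causal_block_size = 2    # tokens inside a block attend to each other freely

# coarse causal mask at block resolution
bi, bj = ceil(i / causal_block_size), ceil(j / causal_block_size)
causal_mask = torch.ones((bi, bj), dtype = torch.bool).triu(bj - bi + 1)

# expand each block entry back to token resolution, then trim to the true lengths
causal_mask = repeat(causal_mask, 'i j -> (i b1) (j b2)', b1 = causal_block_size, b2 = causal_block_size)
causal_mask = causal_mask[:i, :j]

print(causal_mask.long())
# rows 0-1 mask columns 2-5, rows 2-3 mask columns 4-5, rows 4-5 mask nothing,
# i.e. full attention within a block, causal attention across blocks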
@@ -474,8 +487,13 @@ def naive_attend(

 # flex attention related and factory function for attend depending on whether on cuda + flex attention available

-def block_mask_causal(b, h, q, k):
-    return q >= k
+def block_mask_causal(block_size):
+
+    def inner(b, h, q, k):
+        bq = q // block_size
+        bk = k // block_size
+        return bq >= bk
+    return inner

 def agent_token_mask(q, k, seq_len, num_tokens):
     is_special_start_index = seq_len - num_tokens
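For reference, a minimal sketch of how a mask function shaped like inner plugs into PyTorch flex attention (assuming torch >= 2.5 and a CUDA device; shapes and block size are arbitrary, not from the commit):

import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

def block_mask_causal(block_size):
    def inner(b, h, q, k):
        return (q // block_size) >= (k // block_size)
    return inner

q, k, v = (torch.randn(1, 4, 128, 64, device = 'cuda') for _ in range(3))

# materialize the sparse block mask once, then reuse it across attention calls
block_mask = create_block_mask(block_mask_causal(8), B = None, H = None, Q_LEN = 128, KV_LEN = 128)

out = flex_attention(q, k, v, block_mask = block_mask)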
@@ -519,6 +537,7 @@ def get_attend_fn(
     seq_len,
     k_seq_len,
     causal = False,
+    causal_block_size = 1,
     softclamp_value = 50.,
     num_agent_tokens = 0,
     device = None
@@ -526,7 +545,7 @@ def get_attend_fn(
     if use_flex:
         # flex pathway

-        block_mask_fn = block_mask_causal if causal else block_mask_noop
+        block_mask_fn = block_mask_causal(causal_block_size) if causal else block_mask_noop

         if num_agent_tokens > 0:
             agent_block_mask = block_mask_special_tokens_right(k_seq_len, num_agent_tokens)
@@ -546,7 +565,7 @@ def get_attend_fn(

             mask = agent_token_mask(q_seq, k_seq, k_seq_len, num_agent_tokens)

-        attend_fn = partial(naive_attend, causal = causal, mask = mask, softclamp_value = softclamp_value)
+        attend_fn = partial(naive_attend, causal = causal, causal_block_size = causal_block_size, mask = mask, softclamp_value = softclamp_value)

     return attend_fn

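As a usage sketch, the naive pathway the partial above produces can be exercised on cpu via the factory (shapes here are illustrative, not from the commit):

import torch
from dreamer4.dreamer4 import get_attend_fn

q = torch.randn(1, 4, 64, 32)
k = torch.randn(1, 4, 64, 32)
v = torch.randn(1, 4, 64, 32)

# use_flex = False selects the naive_attend partial shown above
attend = get_attend_fn(False, seq_len = 64, k_seq_len = 64, causal = True, causal_block_size = 8)

out = attend(q, k, v)   # (1, 4, 64, 32), block causal over 8-token blocks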
@@ -858,15 +877,25 @@ class VideoTokenizer(Module):

         tokens, packed_latent_shape = pack((tokens, latents), 'b t * d')

+        space_seq_len = tokens.shape[-2]
+
         # pack time

         tokens, inverse_pack_time = pack_one(tokens, 'b * d')

+        seq_len = tokens.shape[1]
+
         # attend hyper parameters

-        attend_kwargs = dict(softclamp_value = self.attn_softclamp_value)
+        attend_kwargs = dict(
+            causal = True,
+            causal_block_size = space_seq_len,
+            softclamp_value = self.attn_softclamp_value
+        )

-        attend_fn = partial(naive_attend, **attend_kwargs)
+        use_flex = tokens.is_cuda and exists(flex_attention)
+
+        attend_fn = get_attend_fn(use_flex, seq_len, seq_len, **attend_kwargs)

         # encoder

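The choice causal_block_size = space_seq_len makes the flattened (time, space) sequence block causal: all tokens of one timestep, including the latents packed in alongside them, attend to each other freely, while timesteps only see the past. A tiny hypothetical check (numbers invented for illustration):

# hypothetical sizes, not from the commit
time_steps = 4
space_seq_len = 16 + 1                  # spatial tokens per timestep plus a packed latent
seq_len = time_steps * space_seq_len

def can_attend(q_index, k_index):
    # same rule as block_mask_causal(space_seq_len)
    return (q_index // space_seq_len) >= (k_index // space_seq_len)

assert can_attend(0, 16)                # first frame token sees its own timestep's latent
assert not can_attend(16, 17)           # ... but not the next timestep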
@@ -63,10 +63,12 @@ def test_symexp_two_hot():
 @param('causal', (False, True))
 @param('softclamp_value', (50., None))
 @param('num_agent_tokens', (0, 1))
+@param('causal_block_size', (1, 8))
 def test_attend_factory(
     causal,
     softclamp_value,
-    num_agent_tokens
+    num_agent_tokens,
+    causal_block_size
 ):

     from dreamer4.dreamer4 import get_attend_fn
@@ -75,7 +77,15 @@ def test_attend_factory(
     k = torch.randn(1, 4, 1024, 512).cuda()
     v = torch.randn(1, 4, 1024, 512).cuda()

-    attend_kwargs = dict(seq_len = 1024, k_seq_len = 1024, causal = causal, softclamp_value = softclamp_value, device = q.device, num_agent_tokens = num_agent_tokens)
+    attend_kwargs = dict(
+        seq_len = 1024,
+        k_seq_len = 1024,
+        causal = causal,
+        causal_block_size = causal_block_size,
+        softclamp_value = softclamp_value,
+        device = q.device,
+        num_agent_tokens = num_agent_tokens
+    )

     attend = get_attend_fn(True, **attend_kwargs)
     flex_out = attend(q, k, v)
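The diff view cuts off after the flex call; presumably the test then checks the naive pathway against it on the same inputs. A sketch of that comparison (the tolerance is an assumption, not from the commit):

    # same settings, naive pathway on the identical inputs
    attend = get_attend_fn(False, **attend_kwargs)
    naive_out = attend(q, k, v)

    assert torch.allclose(flex_out, naive_out, atol = 1e-5)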
|
|||||||
Loading…
x
Reference in New Issue
Block a user