first nail down the attention masking for the dynamics transformer model using a factory function
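The diff below replaces the hard-wired partial(naive_attend, ...) calls in DynamicsModel with a single get_attend_fn factory that returns either a flex-attention-based or a naive attend function, depending on whether flex attention is available and the tensors are on CUDA. A rough usage sketch, mirroring the new test's import; shapes and argument values here are illustrative, not taken from the repository:

# illustrative sketch of the factory added in this commit; shapes are made up
import torch
from dreamer4.dreamer4 import get_attend_fn

q = torch.randn(1, 8, 256, 64)    # (batch, query heads, seq, head dim)
k = torch.randn(1, 4, 256, 64)    # fewer key/value heads, grouped-query attention
v = torch.randn(1, 4, 256, 64)

use_flex = q.is_cuda              # flex attention pathway is only taken on cuda

attend = get_attend_fn(use_flex, seq_len = 256, k_seq_len = 256, causal = True, num_agent_tokens = 1, device = q.device)
out = attend(q, k, v)             # (1, 8, 256, 64)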

lucidrains 2025-10-04 11:20:57 -07:00
parent ca700ba8e1
commit 6c994db341
2 changed files with 131 additions and 34 deletions


@@ -416,29 +416,7 @@ def flex_block_mask(
    block_mask = create_block_mask(create_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)
    return block_mask
# for softclamping with flex attention
def softclamp_score_mod(value):
    def inner(attn_logits, b, h, qi, ki):
        attn_logits = attn_logits / value
        attn_logits = torch.tanh(attn_logits)
        attn_logits = attn_logits * value
        return attn_logits
    return inner
# todo - reuse the inner function from flex attn above with broadcasting
def block_causal_mask(seq_len, block_size, device = None):
    blocks = ceil(seq_len / block_size)
    causal_mask = torch.ones((blocks, blocks), device = device, dtype = torch.bool).tril()
    block_causal_mask = repeat(causal_mask, 'i j -> (i block_size1) (j block_size2)', block_size1 = block_size, block_size2 = block_size)
    return block_causal_mask[:seq_len, :seq_len]
# attend functions
# naive attend
def naive_attend(
    q, k, v,
@@ -461,15 +439,15 @@ def naive_attend(
    sim = einsum(q, k, 'b h g i d, b h j d -> b h g i j')
    # scale and attention
    sim = sim * scale
    # softclamping a la gemma 3
    if exists(softclamp_value):
        sim = softclamp(sim, softclamp_value)
    # scale and attention
    sim = sim * scale
    # masking
    mask_value = -torch.finfo(sim.dtype).max
@@ -488,9 +466,89 @@ def naive_attend(
    # aggregate
    out = einsum(attn, v, 'b h g i j, b h j d -> b h i d')
    out = einsum(attn, v, 'b h g i j, b h j d -> b h g i d')
    return out
    # merge the groups
    return rearrange(out, 'b h g i d -> b (h g) i d')
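The grouped einsums above assume queries come in as (batch, kv heads, groups, query len, dim), so each group of query heads shares one key/value head. A quick, self-contained shape check of that convention (numbers are illustrative):

# illustrative shape check of the grouped-query convention used by naive_attend
import torch
from einops import einsum, rearrange

b, h, g, n, d = 1, 4, 2, 16, 8                        # 4 kv heads x 2 groups = 8 query heads
q = torch.randn(b, h, g, n, d)
k = torch.randn(b, h, n, d)
v = torch.randn(b, h, n, d)

sim = einsum(q, k, 'b h g i d, b h j d -> b h g i j')
attn = sim.softmax(dim = -1)
out = einsum(attn, v, 'b h g i j, b h j d -> b h g i d')
out = rearrange(out, 'b h g i d -> b (h g) i d')      # merge groups back into 8 query heads
assert out.shape == (b, h * g, n, d)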
# flex attention related and factory function for attend depending on whether on cuda + flex attention available
def block_mask_causal(b, h, q, k):
    return q >= k
def agent_token_mask(q, k, seq_len, num_tokens):
    is_special_start_index = seq_len - num_tokens
    q_is_special = q >= is_special_start_index
    k_is_special = k >= is_special_start_index
    return ~(~q_is_special & k_is_special) # modality cannot attend to agent tokens
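To make the rule concrete: with seq_len = 4 and num_tokens = 1 (sizes picked purely for illustration), position 3 is the agent token; modality queries (0-2) may not attend to it, while the agent query may attend everywhere.

# tiny worked example of agent_token_mask (sizes are made up)
import torch
from dreamer4.dreamer4 import agent_token_mask   # assumed import path, mirroring the test's import

q = torch.arange(4)[:, None]    # query positions 0..3
k = torch.arange(4)[None, :]    # key positions 0..3
mask = agent_token_mask(q, k, seq_len = 4, num_tokens = 1)
# rows 0-2 (modality queries): True for keys 0-2, False for key 3 (agent token blocked)
# row 3 (agent query):         True for every key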
def block_mask_special_tokens_right(
    seq_len,
    num_tokens
):
    def inner(b, h, q, k):
        return agent_token_mask(q, k, seq_len, num_tokens)
    return inner
def compose_mask(mask1, mask2):
    def inner(b, h, q, k):
        return mask1(b, h, q, k) & mask2(b, h, q, k)
    return inner
def block_mask_noop(b, h, q, k):
    return b >= 0
def score_mod_softclamp(value):
    def inner(sim, b, h, q, k):
        if not exists(value):
            return sim
        sim = sim / value
        sim = torch.tanh(sim)
        sim = sim * value
        return sim
    return inner
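The score mod is the usual tanh softcapping, i.e. roughly f(x) = value * tanh(x / value): small logits pass through nearly unchanged while large ones saturate at plus or minus value. A quick sanity check, assuming score_mod_softclamp is importable from the module above:

# illustrative check that the score mod equals value * tanh(x / value)
import torch
from dreamer4.dreamer4 import score_mod_softclamp   # assumed import path

clamp = score_mod_softclamp(50.)
x = torch.tensor([-1000., -1., 0., 1., 1000.])
y = clamp(x, None, None, None, None)                # batch / head / q / k indices are unused here
assert torch.allclose(y, 50. * torch.tanh(x / 50.))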
# factory for attend function
def get_attend_fn(
    use_flex,
    seq_len,
    k_seq_len,
    causal = False,
    softclamp_value = 50.,
    num_agent_tokens = 0,
    device = None
):
    if use_flex:
        # flex pathway
        block_mask_fn = block_mask_causal if causal else block_mask_noop
        if num_agent_tokens > 0:
            agent_block_mask = block_mask_special_tokens_right(k_seq_len, num_agent_tokens)
            block_mask_fn = compose_mask(block_mask_fn, agent_block_mask)
        block_mask = create_block_mask(block_mask_fn, B = None, H = None, Q_LEN = seq_len, KV_LEN = k_seq_len)
        score_mod = score_mod_softclamp(softclamp_value)
        attend_fn = partial(flex_attention, block_mask = block_mask, score_mod = score_mod, enable_gqa = True)
    else:
        # naive pathway
        mask = None
        if num_agent_tokens > 0:
            q_seq = torch.arange(seq_len, device = device)[:, None]
            k_seq = torch.arange(k_seq_len, device = device)[None, :]
            mask = agent_token_mask(q_seq, k_seq, k_seq_len, num_agent_tokens)
        attend_fn = partial(naive_attend, causal = causal, mask = mask, softclamp_value = softclamp_value)
    return attend_fn
# attention
@@ -521,7 +579,7 @@ class Attention(Module):
        self.to_q = LinearNoBias(dim, dim_q_inner)
        self.to_kv = LinearNoBias(dim, dim_kv_inner * 2)
        self.to_out = LinearNoBias(dim_kv_inner, dim)
        self.to_out = LinearNoBias(dim_q_inner, dim)
        # stability related
@@ -949,6 +1007,15 @@ class DynamicsModel(Module):
        self.action_learned_embed = Parameter(torch.randn(dim) * 1e-2)
        # calculate "space" seq len
        self.space_seq_len = (
            1 # action / agent token
            + 1 # signal + step
            + num_register_tokens
            + num_spatial_tokens
        )
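For a sense of scale, with hypothetical values num_register_tokens = 4 and num_spatial_tokens = 256 (not necessarily the repo defaults), each time step would contribute 1 + 1 + 4 + 256 = 262 space tokens, with the single agent / action token at the right end - which is why space attention further down is constructed with num_agent_tokens = 1.

# hypothetical sizing example (values made up)
num_register_tokens = 4
num_spatial_tokens = 256
space_seq_len = 1 + 1 + num_register_tokens + num_spatial_tokens   # = 262 tokens per time step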
        # attention
        self.attn_softclamp_value = attn_softclamp_value
@@ -1013,7 +1080,7 @@ class DynamicsModel(Module):
        latents = self.video_tokenizer.tokenize(video)
        time = latents.shape[1]
        time, device = latents.shape[1], latents.device
        # flow related
@@ -1070,11 +1137,15 @@ class DynamicsModel(Module):
        # attend functions for space and time
        attend_kwargs = dict(softclamp_value = self.attn_softclamp_value)
        seq_len = tokens.shape[1]
        space_attend = partial(naive_attend, causal = False, **attend_kwargs)
        use_flex = exists(flex_attention) and tokens.is_cuda
        time_attend = partial(naive_attend, causal = True, **attend_kwargs)
        attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, device = device)
        space_attend = get_attend_fn(causal = False, seq_len = self.space_seq_len, k_seq_len = self.space_seq_len, num_agent_tokens = 1, **attend_kwargs) # space has an agent token on the right-hand side for reinforcement learning - cannot be attended to by modality
        time_attend = get_attend_fn(causal = True, seq_len = time, k_seq_len = time, **attend_kwargs)
        # rotary


@@ -58,3 +58,29 @@ def test_symexp_two_hot():
    recon_values = two_hot_encoder.logits_to_scalar_value(encoded)
    assert torch.allclose(recon_values, values, atol = 1e-6)
@pytest.mark.skipif(not torch.cuda.is_available(), reason = 'no cuda')
@param('causal', (False, True))
@param('softclamp_value', (50., None))
@param('num_agent_tokens', (0, 1))
def test_attend_factory(
    causal,
    softclamp_value,
    num_agent_tokens
):
    from dreamer4.dreamer4 import get_attend_fn
    q = torch.randn(1, 8, 1024, 512).cuda()
    k = torch.randn(1, 4, 1024, 512).cuda()
    v = torch.randn(1, 4, 1024, 512).cuda()
    attend_kwargs = dict(seq_len = 1024, k_seq_len = 1024, causal = causal, softclamp_value = softclamp_value, device = q.device, num_agent_tokens = num_agent_tokens)
    attend = get_attend_fn(True, **attend_kwargs)
    flex_out = attend(q, k, v)
    attend = get_attend_fn(False, **attend_kwargs)
    out = attend(q, k, v)
    assert torch.allclose(flex_out, out, atol = 1e-6)