first nail down the attention masking for the dynamics transformer model using a factory function
parent ca700ba8e1 · commit 6c994db341
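As a quick orientation (not part of the diff): the new factory composes a causal rule with a rule that stops modality tokens from attending to agent tokens placed at the right end of the sequence. The standalone sketch below mirrors agent_token_mask, block_mask_causal and compose_mask from the diff on a toy sequence length; the names re-defined here and the tiny sizes are illustrative only.

# illustrative sketch - mirrors agent_token_mask / compose_mask from the diff below
import torch

def agent_token_mask(q, k, seq_len, num_tokens):
    is_special_start_index = seq_len - num_tokens
    q_is_special = q >= is_special_start_index
    k_is_special = k >= is_special_start_index
    return ~(~q_is_special & k_is_special) # modality cannot attend to agent tokens

seq_len, num_agent_tokens = 6, 1

q_idx = torch.arange(seq_len)[:, None]
k_idx = torch.arange(seq_len)[None, :]

causal_mask = q_idx >= k_idx                                            # same rule as block_mask_causal
agent_mask = agent_token_mask(q_idx, k_idx, seq_len, num_agent_tokens)  # same rule as block_mask_special_tokens_right
mask = causal_mask & agent_mask                                         # same composition as compose_mask

print(mask.int()) # the last column (the agent token) is only visible to the agent row itself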
@@ -416,29 +416,7 @@ def flex_block_mask(

    block_mask = create_block_mask(create_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)

    return block_mask

# for softclamping with flex attention

def softclamp_score_mod(value):

    def inner(attn_logits, b, h, qi, ki):
        attn_logits = attn_logits / value
        attn_logits = torch.tanh(attn_logits)
        attn_logits = attn_logits * value
        return attn_logits

    return inner

# todo - reuse the inner function from flex attn above with broadcasting

def block_causal_mask(seq_len, block_size, device = None):
    blocks = ceil(seq_len / block_size)

    causal_mask = torch.ones((blocks, blocks), device = device, dtype = torch.bool).tril()
    block_causal_mask = repeat(causal_mask, 'i j -> (i block_size1) (j block_size2)', block_size1 = block_size, block_size2 = block_size)

    return block_causal_mask[:seq_len, :seq_len]

# attend functions
# naive attend

def naive_attend(
    q, k, v,
@@ -461,15 +439,15 @@ def naive_attend(

    sim = einsum(q, k, 'b h g i d, b h j d -> b h g i j')

    # scale and attention

    sim = sim * scale

    # softclamping a la gemma 3

    if exists(softclamp_value):
        sim = softclamp(sim, softclamp_value)

    # scale and attention

    sim = sim * scale

    # masking

    mask_value = -torch.finfo(sim.dtype).max
@@ -488,9 +466,89 @@ def naive_attend(

    # aggregate

    out = einsum(attn, v, 'b h g i j, b h j d -> b h i d')
    out = einsum(attn, v, 'b h g i j, b h j d -> b h g i d')

    return out

    # merge the groups

    return rearrange(out, 'b h g i d -> b (h g) i d')

# flex attention related and factory function for attend depending on whether on cuda + flex attention available

def block_mask_causal(b, h, q, k):
    return q >= k

def agent_token_mask(q, k, seq_len, num_tokens):
    is_special_start_index = seq_len - num_tokens
    q_is_special = q >= is_special_start_index
    k_is_special = k >= is_special_start_index
    return ~(~q_is_special & k_is_special) # modality cannot attend to agent tokens

def block_mask_special_tokens_right(
    seq_len,
    num_tokens
):
    def inner(b, h, q, k):
        return agent_token_mask(q, k, seq_len, num_tokens)
    return inner

def compose_mask(mask1, mask2):
    def inner(b, h, q, k):
        return mask1(b, h, q, k) & mask2(b, h, q, k)

    return inner

def block_mask_noop(b, h, q, k):
    return b >= 0

def score_mod_softclamp(value):
    def inner(sim, b, h, q, k):
        if not exists(value):
            return sim

        sim = sim / value
        sim = torch.tanh(sim)
        sim = sim * value
        return sim

    return inner

# factory for attend function

def get_attend_fn(
    use_flex,
    seq_len,
    k_seq_len,
    causal = False,
    softclamp_value = 50.,
    num_agent_tokens = 0,
    device = None
):
    if use_flex:
        # flex pathway

        block_mask_fn = block_mask_causal if causal else block_mask_noop

        if num_agent_tokens > 0:
            agent_block_mask = block_mask_special_tokens_right(k_seq_len, num_agent_tokens)
            block_mask_fn = compose_mask(block_mask_fn, agent_block_mask)

        block_mask = create_block_mask(block_mask_fn, B = None, H = None, Q_LEN = seq_len, KV_LEN = k_seq_len)

        score_mod = score_mod_softclamp(softclamp_value)
        attend_fn = partial(flex_attention, block_mask = block_mask, score_mod = score_mod, enable_gqa = True)
    else:
        # naive pathway

        mask = None
        if num_agent_tokens > 0:
            q_seq = torch.arange(seq_len, device = device)[:, None]
            k_seq = torch.arange(k_seq_len, device = device)[None, :]

            mask = agent_token_mask(q_seq, k_seq, k_seq_len, num_agent_tokens)

        attend_fn = partial(naive_attend, causal = causal, mask = mask, softclamp_value = softclamp_value)

    return attend_fn

# attention
@@ -521,7 +579,7 @@ class Attention(Module):

        self.to_q = LinearNoBias(dim, dim_q_inner)
        self.to_kv = LinearNoBias(dim, dim_kv_inner * 2)
        self.to_out = LinearNoBias(dim_kv_inner, dim)
        self.to_out = LinearNoBias(dim_q_inner, dim)

        # stability related
@@ -949,6 +1007,15 @@ class DynamicsModel(Module):

        self.action_learned_embed = Parameter(torch.randn(dim) * 1e-2)

        # calculate "space" seq len

        self.space_seq_len = (
            1 # action / agent token
            + 1 # signal + step
            + num_register_tokens
            + num_spatial_tokens
        )

        # attention

        self.attn_softclamp_value = attn_softclamp_value
@@ -1013,7 +1080,7 @@ class DynamicsModel(Module):

        latents = self.video_tokenizer.tokenize(video)

        time = latents.shape[1]
        time, device = latents.shape[1], latents.device

        # flow related
@@ -1070,11 +1137,15 @@ class DynamicsModel(Module):

        # attend functions for space and time

        attend_kwargs = dict(softclamp_value = self.attn_softclamp_value)
        seq_len = tokens.shape[1]

        space_attend = partial(naive_attend, causal = False, **attend_kwargs)
        use_flex = exists(flex_attention) and tokens.is_cuda

        time_attend = partial(naive_attend, causal = True, **attend_kwargs)
        attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, device = device)

        space_attend = get_attend_fn(causal = False, seq_len = self.space_seq_len, k_seq_len = self.space_seq_len, num_agent_tokens = 1, **attend_kwargs) # space has an agent token on the right-hand side for reinforcement learning - cannot be attended to by modality

        time_attend = get_attend_fn(causal = True, seq_len = time, k_seq_len = time, **attend_kwargs)

        # rotary
@@ -58,3 +58,29 @@ def test_symexp_two_hot():

    recon_values = two_hot_encoder.logits_to_scalar_value(encoded)

    assert torch.allclose(recon_values, values, atol = 1e-6)

@pytest.mark.skipif(not torch.cuda.is_available(), reason = 'no cuda')
@param('causal', (False, True))
@param('softclamp_value', (50., None))
@param('num_agent_tokens', (0, 1))
def test_attend_factory(
    causal,
    softclamp_value,
    num_agent_tokens
):
    from dreamer4.dreamer4 import get_attend_fn

    q = torch.randn(1, 8, 1024, 512).cuda()
    k = torch.randn(1, 4, 1024, 512).cuda()
    v = torch.randn(1, 4, 1024, 512).cuda()

    attend_kwargs = dict(seq_len = 1024, k_seq_len = 1024, causal = causal, softclamp_value = softclamp_value, device = q.device, num_agent_tokens = num_agent_tokens)

    attend = get_attend_fn(True, **attend_kwargs)
    flex_out = attend(q, k, v)

    attend = get_attend_fn(False, **attend_kwargs)
    out = attend(q, k, v)

    assert torch.allclose(flex_out, out, atol = 1e-6)
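The test above only runs on CUDA, where the flex and naive pathways can be compared directly. The naive pathway alone should also work on CPU; a minimal sketch, assuming the grouped-query q/k/v layout from the test (8 query heads over 4 key/value heads) and an arbitrary head dimension:

# minimal CPU sketch of the naive pathway - shapes are assumptions following the test above
import torch
from dreamer4.dreamer4 import get_attend_fn

q = torch.randn(1, 8, 256, 64)
k = torch.randn(1, 4, 256, 64)
v = torch.randn(1, 4, 256, 64)

attend = get_attend_fn(False, seq_len = 256, k_seq_len = 256, causal = True, softclamp_value = 50., num_agent_tokens = 1, device = q.device)

out = attend(q, k, v)
assert out.shape == q.shape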