the function for generating the MTP targets, as well as the mask for the losses

This commit is contained in:
lucidrains 2025-10-18 08:04:51 -07:00
parent 83cfd2cd1b
commit 5fc0022bbf
2 changed files with 38 additions and 0 deletions

View File

@ -173,6 +173,28 @@ def l2norm(t):
def softclamp(t, value = 50.):
return (t / value).tanh() * value
def create_multi_token_prediction_targets(
t, # (b t ...)
steps_future
): # (b t-1 steps ...), (b t-1 steps) - targets and the mask, where mask is False for padding
batch, seq_len, device = *t.shape[:2], t.device
batch_arange = arange(batch, device = device)
seq_arange = arange(seq_len, device = device)[1:]
steps_arange = arange(steps_future, device = device)
indices = add('t, steps -> t steps', seq_arange, steps_arange)
mask = indices < seq_len
batch_arange = rearrange(batch_arange, 'b -> b 1 1')
indices[~mask] = 0
mask = repeat(mask, 't steps -> b t steps', b = batch)
return t[batch_arange, indices], mask
# loss related
class LPIPSLoss(Module):

View File

@ -385,3 +385,19 @@ def test_action_embedder():
)
assert torch.allclose(discrete_log_probs, parallel_discrete_log_probs, atol = 1e-5)
def test_mtp():
from dreamer4.dreamer4 import create_multi_token_prediction_targets
rewards = torch.randn(3, 16) # (b t)
reward_targets, mask = create_multi_token_prediction_targets(rewards, 3) # say three token lookahead
assert reward_targets.shape == (3, 15, 3)
assert mask.shape == (3, 15, 3)
actions = torch.randint(0, 10, (3, 16, 2))
action_targets, mask = create_multi_token_prediction_targets(actions, 3)
assert action_targets.shape == (3, 15, 3, 2)
assert mask.shape == (3, 15, 3)