diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index c263d26..e966349 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -173,6 +173,28 @@ def l2norm(t): def softclamp(t, value = 50.): return (t / value).tanh() * value +def create_multi_token_prediction_targets( + t, # (b t ...) + steps_future + +): # (b t-1 steps ...), (b t-1 steps) - targets and the mask, where mask is False for padding + + batch, seq_len, device = *t.shape[:2], t.device + + batch_arange = arange(batch, device = device) + seq_arange = arange(seq_len, device = device)[1:] + steps_arange = arange(steps_future, device = device) + + indices = add('t, steps -> t steps', seq_arange, steps_arange) + mask = indices < seq_len + + batch_arange = rearrange(batch_arange, 'b -> b 1 1') + + indices[~mask] = 0 + mask = repeat(mask, 't steps -> b t steps', b = batch) + + return t[batch_arange, indices], mask + # loss related class LPIPSLoss(Module): diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py index 7a0a0ca..943dcbe 100644 --- a/tests/test_dreamer.py +++ b/tests/test_dreamer.py @@ -385,3 +385,19 @@ def test_action_embedder(): ) assert torch.allclose(discrete_log_probs, parallel_discrete_log_probs, atol = 1e-5) + +def test_mtp(): + from dreamer4.dreamer4 import create_multi_token_prediction_targets + + rewards = torch.randn(3, 16) # (b t) + + reward_targets, mask = create_multi_token_prediction_targets(rewards, 3) # say three token lookahead + + assert reward_targets.shape == (3, 15, 3) + assert mask.shape == (3, 15, 3) + + actions = torch.randint(0, 10, (3, 16, 2)) + action_targets, mask = create_multi_token_prediction_targets(actions, 3) + + assert action_targets.shape == (3, 15, 3, 2) + assert mask.shape == (3, 15, 3)