the function for generating the MTP targets, as well as the mask for the losses
parent 83cfd2cd1b
commit 5fc0022bbf
@@ -173,6 +173,28 @@ def l2norm(t):
def softclamp(t, value = 50.):
    return (t / value).tanh() * value

def create_multi_token_prediction_targets(
    t,            # (b t ...)
    steps_future
):                # (b t-1 steps ...), (b t-1 steps) - targets and the mask, where mask is False for padding

    batch, seq_len, device = *t.shape[:2], t.device

    batch_arange = arange(batch, device = device)
    seq_arange = arange(seq_len, device = device)[1:]
    steps_arange = arange(steps_future, device = device)

    # for each position, the indices of the next `steps_future` tokens, via a broadcasted add
    indices = add('t, steps -> t steps', seq_arange, steps_arange)
    mask = indices < seq_len

    batch_arange = rearrange(batch_arange, 'b -> b 1 1')

    # zero out the out-of-bounds indices so the gather stays in range; the mask excludes them from the loss
    indices[~mask] = 0
    mask = repeat(mask, 't steps -> b t steps', b = batch)

    return t[batch_arange, indices], mask

# loss related

class LPIPSLoss(Module):
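For intuition, softclamp is a smooth tanh-based clamp: inputs small relative to `value` pass through nearly unchanged, while large inputs saturate at ±value. A minimal standalone sketch, not part of the commit:

import torch

def softclamp(t, value = 50.):
    return (t / value).tanh() * value

x = torch.tensor([1., 25., 100., 1000.])
print(softclamp(x))  # ≈ [1.0, 23.1, 48.2, 50.0], bounded to (-50, 50)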
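And a small worked example of the new target builder, assuming the module-level helpers used above (arange as torch.arange, add from einx, rearrange and repeat from einops); this snippet is a sketch for illustration, not part of the commit:

import torch
from dreamer4.dreamer4 import create_multi_token_prediction_targets

t = torch.tensor([[10, 11, 12, 13]])  # (b = 1, t = 4)

targets, mask = create_multi_token_prediction_targets(t, 2)

# row i holds the tokens at positions i + 1 .. i + steps_future; out-of-range
# slots are clamped to index 0 and marked False in the mask
assert targets.tolist() == [[[11, 12], [12, 13], [13, 10]]]
assert mask.tolist() == [[[True, True], [True, True], [True, False]]]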
@@ -385,3 +385,19 @@ def test_action_embedder():
    )

    assert torch.allclose(discrete_log_probs, parallel_discrete_log_probs, atol = 1e-5)

def test_mtp():
    from dreamer4.dreamer4 import create_multi_token_prediction_targets

    rewards = torch.randn(3, 16) # (b t)

    reward_targets, mask = create_multi_token_prediction_targets(rewards, 3) # say three token lookahead

    assert reward_targets.shape == (3, 15, 3)
    assert mask.shape == (3, 15, 3)

    actions = torch.randint(0, 10, (3, 16, 2))
    action_targets, mask = create_multi_token_prediction_targets(actions, 3)

    assert action_targets.shape == (3, 15, 3, 2)
    assert mask.shape == (3, 15, 3)
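Finally, a hedged sketch of how the returned mask is typically consumed downstream (the `preds` tensor here is a hypothetical model output, not part of the commit): padded slots are simply excluded from the multi-token prediction loss.

import torch
import torch.nn.functional as F
from dreamer4.dreamer4 import create_multi_token_prediction_targets

rewards = torch.randn(3, 16)   # (b t)
preds = torch.randn(3, 15, 3)  # (b t-1 steps), hypothetical predictions

targets, mask = create_multi_token_prediction_targets(rewards, 3)

loss = F.mse_loss(preds, targets, reduction = 'none')
loss = loss[mask].mean()       # average only over valid (non-padded) slots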