Add the function for generating the multi-token prediction (MTP) targets, as well as the mask for the losses
This commit is contained in:
parent
83cfd2cd1b
commit
5fc0022bbf
@ -173,6 +173,28 @@ def l2norm(t):
|
|||||||
def softclamp(t, value = 50.):
    """Smoothly bound `t` to the range [-value, value] with a rescaled tanh."""
    scaled = t / value
    return scaled.tanh() * value
|
||||||
|
|
||||||
|
def create_multi_token_prediction_targets(
    t, # (b t ...)
    steps_future
): # (b t-1 steps ...), (b t-1 steps) - targets and the mask, where mask is False for padding
    """
    Gather, for every timestep except the last, the next `steps_future` entries of `t` along the time dimension.

    Returns the gathered targets together with a boolean mask that is False wherever the
    lookahead runs past the end of the sequence. Those padded slots hold the entry at
    time index 0 and are only meaningful when combined with the mask.
    """
    device = t.device
    batch, seq_len = t.shape[:2]

    # lookahead offsets 0..steps-1, added to each starting position 1..t-1
    future_offsets = arange(steps_future, device = device)
    start_positions = arange(seq_len, device = device)[1:]

    time_indices = add('t, steps -> t steps', start_positions, future_offsets)

    # anything at or beyond the sequence length is padding
    mask = time_indices < seq_len

    # point out-of-range lookups at index 0 so the gather below stays in bounds
    time_indices = time_indices.masked_fill(~mask, 0)

    # shaped (b 1 1) so it broadcasts against the (t steps) time indices
    batch_indices = rearrange(arange(batch, device = device), 'b -> b 1 1')

    targets = t[batch_indices, time_indices]
    mask = repeat(mask, 't steps -> b t steps', b = batch)

    return targets, mask
|
||||||
|
|
||||||
# loss related
|
# loss related
|
||||||
|
|
||||||
class LPIPSLoss(Module):
|
class LPIPSLoss(Module):
|
||||||
|
|||||||
@ -385,3 +385,19 @@ def test_action_embedder():
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert torch.allclose(discrete_log_probs, parallel_discrete_log_probs, atol = 1e-5)
|
assert torch.allclose(discrete_log_probs, parallel_discrete_log_probs, atol = 1e-5)
|
||||||
|
|
||||||
|
def test_mtp():
    """Check the shapes, gathered values and padding mask of the multi token prediction targets."""
    from dreamer4.dreamer4 import create_multi_token_prediction_targets

    rewards = torch.randn(3, 16) # (b t)

    reward_targets, mask = create_multi_token_prediction_targets(rewards, 3) # say three token lookahead

    assert reward_targets.shape == (3, 15, 3)
    assert mask.shape == (3, 15, 3)

    # the first lookahead step is just the sequence shifted by one, and is never padding
    assert torch.equal(reward_targets[..., 0], rewards[:, 1:])
    assert mask[..., 0].all()

    # with 16 timesteps and 3 step lookahead, only the last two rows run past the end
    assert mask[:, :-2].all()
    assert mask[:, -2].tolist() == [[True, True, False]] * 3
    assert mask[:, -1].tolist() == [[True, False, False]] * 3

    actions = torch.randint(0, 10, (3, 16, 2))

    action_targets, action_mask = create_multi_token_prediction_targets(actions, 3)

    assert action_targets.shape == (3, 15, 3, 2)
    assert action_mask.shape == (3, 15, 3)

    # trailing feature dimensions are carried through the gather untouched
    assert torch.equal(action_targets[:, :, 0], actions[:, 1:])
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user