take care of the loss weight proposed in eq 8

lucidrains 2025-10-03 08:19:38 -07:00
parent 046f8927d1
commit 90bf19f076


@@ -74,6 +74,12 @@ def l2norm(t):
 def softclamp(t, value = 50.):
     return (t / value).tanh() * value
 
+# loss related
+
+def ramp_weight(times, slope = 0.9, intercept = 0.1):
+    # equation (8) of the paper, their "ramp" loss weighting
+    return slope * times + intercept
+
 # reinforcement learning related
 
 # rewards
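For reference, a minimal standalone sketch of what the new helper computes (the import and printed values below are illustrative, not part of the commit): with slope 0.9 and intercept 0.1, the weight ramps linearly from 0.1 at times = 0 (pure noise) to 1.0 at times = 1 (clean latent), so heavily noised steps contribute less to the training loss.

import torch

def ramp_weight(times, slope = 0.9, intercept = 0.1):
    # equation (8): w(t) = slope * t + intercept
    return slope * times + intercept

times = torch.tensor([0., 0.25, 0.5, 1.])
print(ramp_weight(times))   # tensor([0.1000, 0.3250, 0.5500, 1.0000])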
@@ -609,7 +615,8 @@ class DynamicsModel(Module):
             dim_head = 64,
             heads = 8,
         ),
-        ff_kwargs: dict = dict()
+        ff_kwargs: dict = dict(),
+        loss_weight_fn: Callable = ramp_weight,
     ):
         super().__init__()
 
@@ -635,6 +642,8 @@ class DynamicsModel(Module):
         self.pred_orig_latent = pred_orig_latent
 
+        self.loss_weight_fn = loss_weight_fn
+
         # they sum all the actions into a single token
 
         self.action_learned_embed = Parameter(torch.randn(dim) * 1e-2)
@@ -683,11 +692,11 @@ class DynamicsModel(Module):
         noise = torch.randn_like(latents)
 
-        interp = rearrange(signal_levels.float() / self.num_signal_levels, 'b t -> b t 1')
+        times = rearrange(signal_levels.float() / self.num_signal_levels, 'b t -> b t 1')
 
         orig_latents = latents
 
-        latents = noise.lerp(latents, interp)
+        latents = noise.lerp(latents, times)
 
         # allow for original velocity pred
         # x-space as in paper is in else clause
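The rename from interp to times matters for the next hunk: noise.lerp(latents, times) is noise + times * (latents - noise), i.e. (1 - times) * noise + times * latents, and the same times tensor is what the loss weight is later computed from. A small standalone check of the lerp semantics on dummy shapes (illustrative only):

import torch

noise   = torch.randn(2, 3, 4)
latents = torch.randn(2, 3, 4)
times   = torch.rand(2, 3, 1)          # broadcasts over the last (feature) dim

noised = noise.lerp(latents, times)    # noise + times * (latents - noise)
manual = (1. - times) * noise + times * latents

assert torch.allclose(noised, manual, atol = 1e-6)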
@@ -749,7 +758,13 @@ class DynamicsModel(Module):
         if not flow_matching:
             return pred
 
-        return F.mse_loss(pred, pred_target)
+        losses = F.mse_loss(pred, pred_target, reduction = 'none')
+
+        if exists(self.loss_weight_fn):
+            loss_weight = self.loss_weight_fn(times)
+            losses = losses * loss_weight
+
+        return losses.mean()
 
 # dreamer
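Putting the pieces together, a self-contained sketch of the weighted objective as it now reads in the diff: an unreduced MSE, scaled element-wise by the equation (8) ramp weight broadcast from the (b, t, 1) shaped times, then mean-reduced. The function name, dummy shapes, and the plain None check (standing in for the repo's exists helper) are illustrative assumptions, not code from the commit.

import torch
import torch.nn.functional as F

def weighted_flow_loss(pred, pred_target, times, loss_weight_fn = lambda t: 0.9 * t + 0.1):
    # per-element loss so the ramp weight can be applied before reduction
    losses = F.mse_loss(pred, pred_target, reduction = 'none')

    if loss_weight_fn is not None:               # stands in for exists(self.loss_weight_fn)
        loss_weight = loss_weight_fn(times)      # shape (b, t, 1), broadcasts over the feature dim
        losses = losses * loss_weight

    return losses.mean()

pred        = torch.randn(2, 8, 16)
pred_target = torch.randn(2, 8, 16)
times       = torch.rand(2, 8, 1)
print(weighted_flow_loss(pred, pred_target, times))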