take care of the loss weight proposed in eq 8
This commit is contained in:
parent
046f8927d1
commit
90bf19f076
@ -74,6 +74,12 @@ def l2norm(t):
|
||||
def softclamp(t, value = 50.):
|
||||
return (t / value).tanh() * value
|
||||
|
||||
# loss related
|
||||
|
||||
def ramp_weight(times, slope = 0.9, intercept = 0.1):
|
||||
# equation (8) paper, their "ramp" loss weighting
|
||||
return slope * times + intercept
|
||||
|
||||
# reinforcement learning related
|
||||
|
||||
# rewards
|
||||
@ -609,7 +615,8 @@ class DynamicsModel(Module):
|
||||
dim_head = 64,
|
||||
heads = 8,
|
||||
),
|
||||
ff_kwargs: dict = dict()
|
||||
ff_kwargs: dict = dict(),
|
||||
loss_weight_fn: Callable = ramp_weight,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -635,6 +642,8 @@ class DynamicsModel(Module):
|
||||
|
||||
self.pred_orig_latent = pred_orig_latent
|
||||
|
||||
self.loss_weight_fn = loss_weight_fn
|
||||
|
||||
# they sum all the actions into a single token
|
||||
|
||||
self.action_learned_embed = Parameter(torch.randn(dim) * 1e-2)
|
||||
@ -683,11 +692,11 @@ class DynamicsModel(Module):
|
||||
|
||||
noise = torch.randn_like(latents)
|
||||
|
||||
interp = rearrange(signal_levels.float() / self.num_signal_levels, 'b t -> b t 1')
|
||||
times = rearrange(signal_levels.float() / self.num_signal_levels, 'b t -> b t 1')
|
||||
|
||||
orig_latents = latents
|
||||
|
||||
latents = noise.lerp(latents, interp)
|
||||
latents = noise.lerp(latents, times)
|
||||
|
||||
# allow for original velocity pred
|
||||
# x-space as in paper is in else clause
|
||||
@ -749,7 +758,13 @@ class DynamicsModel(Module):
|
||||
if not flow_matching:
|
||||
return pred
|
||||
|
||||
return F.mse_loss(pred, pred_target)
|
||||
losses = F.mse_loss(pred, pred_target, reduction = 'none')
|
||||
|
||||
if exists(self.loss_weight_fn):
|
||||
loss_weight = self.loss_weight_fn(times)
|
||||
losses = losses * loss_weight
|
||||
|
||||
return losses.mean()
|
||||
|
||||
# dreamer
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user