Apply the loss weighting proposed in eq. (8) of the paper
This commit is contained in:
parent
046f8927d1
commit
90bf19f076
@ -74,6 +74,12 @@ def l2norm(t):
|
|||||||
def softclamp(t, value = 50.):
|
def softclamp(t, value = 50.):
|
||||||
return (t / value).tanh() * value
|
return (t / value).tanh() * value
|
||||||
|
|
||||||
|
# loss related
|
||||||
|
|
||||||
|
def ramp_weight(times, slope = 0.9, intercept = 0.1):
|
||||||
|
# equation (8) of the paper: their "ramp" loss weighting, weight = slope * times + intercept
|
||||||
|
return slope * times + intercept
|
||||||
|
|
||||||
# reinforcement learning related
|
# reinforcement learning related
|
||||||
|
|
||||||
# rewards
|
# rewards
|
||||||
@ -609,7 +615,8 @@ class DynamicsModel(Module):
|
|||||||
dim_head = 64,
|
dim_head = 64,
|
||||||
heads = 8,
|
heads = 8,
|
||||||
),
|
),
|
||||||
ff_kwargs: dict = dict()
|
ff_kwargs: dict = dict(),
|
||||||
|
loss_weight_fn: Callable = ramp_weight,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -635,6 +642,8 @@ class DynamicsModel(Module):
|
|||||||
|
|
||||||
self.pred_orig_latent = pred_orig_latent
|
self.pred_orig_latent = pred_orig_latent
|
||||||
|
|
||||||
|
self.loss_weight_fn = loss_weight_fn
|
||||||
|
|
||||||
# they sum all the actions into a single token
|
# they sum all the actions into a single token
|
||||||
|
|
||||||
self.action_learned_embed = Parameter(torch.randn(dim) * 1e-2)
|
self.action_learned_embed = Parameter(torch.randn(dim) * 1e-2)
|
||||||
@ -683,11 +692,11 @@ class DynamicsModel(Module):
|
|||||||
|
|
||||||
noise = torch.randn_like(latents)
|
noise = torch.randn_like(latents)
|
||||||
|
|
||||||
interp = rearrange(signal_levels.float() / self.num_signal_levels, 'b t -> b t 1')
|
times = rearrange(signal_levels.float() / self.num_signal_levels, 'b t -> b t 1')
|
||||||
|
|
||||||
orig_latents = latents
|
orig_latents = latents
|
||||||
|
|
||||||
latents = noise.lerp(latents, interp)
|
latents = noise.lerp(latents, times)
|
||||||
|
|
||||||
# allow for original velocity pred
|
# allow for original velocity pred
|
||||||
# x-space as in paper is in else clause
|
# x-space as in paper is in else clause
|
||||||
@ -749,7 +758,13 @@ class DynamicsModel(Module):
|
|||||||
if not flow_matching:
|
if not flow_matching:
|
||||||
return pred
|
return pred
|
||||||
|
|
||||||
return F.mse_loss(pred, pred_target)
|
losses = F.mse_loss(pred, pred_target, reduction = 'none')
|
||||||
|
|
||||||
|
if exists(self.loss_weight_fn):
|
||||||
|
loss_weight = self.loss_weight_fn(times)
|
||||||
|
losses = losses * loss_weight
|
||||||
|
|
||||||
|
return losses.mean()
|
||||||
|
|
||||||
# dreamer
|
# dreamer
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user