diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 17bc894..aef87f5 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -74,6 +74,12 @@ def l2norm(t):
 def softclamp(t, value = 50.):
     return (t / value).tanh() * value
 
+# loss related
+
+def ramp_weight(times, slope = 0.9, intercept = 0.1):
+    # equation (8) of the paper, their "ramp" loss weighting
+    return slope * times + intercept
+
 # reinforcement learning related
 
 # rewards
@@ -609,7 +615,8 @@ class DynamicsModel(Module):
             dim_head = 64,
             heads = 8,
         ),
-        ff_kwargs: dict = dict()
+        ff_kwargs: dict = dict(),
+        loss_weight_fn: Callable = ramp_weight,
     ):
         super().__init__()
 
@@ -635,6 +642,8 @@ class DynamicsModel(Module):
 
         self.pred_orig_latent = pred_orig_latent
 
+        self.loss_weight_fn = loss_weight_fn
+
         # they sum all the actions into a single token
 
         self.action_learned_embed = Parameter(torch.randn(dim) * 1e-2)
@@ -683,11 +692,11 @@ class DynamicsModel(Module):
 
         noise = torch.randn_like(latents)
 
-        interp = rearrange(signal_levels.float() / self.num_signal_levels, 'b t -> b t 1')
+        times = rearrange(signal_levels.float() / self.num_signal_levels, 'b t -> b t 1')
 
         orig_latents = latents
 
-        latents = noise.lerp(latents, interp)
+        latents = noise.lerp(latents, times)
 
         # allow for original velocity pred
         # x-space as in paper is in else clause
@@ -749,7 +758,13 @@ class DynamicsModel(Module):
         if not flow_matching:
             return pred
 
-        return F.mse_loss(pred, pred_target)
+        losses = F.mse_loss(pred, pred_target, reduction = 'none')
+
+        if exists(self.loss_weight_fn):
+            loss_weight = self.loss_weight_fn(times)
+            losses = losses * loss_weight
+
+        return losses.mean()
 
 # dreamer
 
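
For context, a minimal standalone sketch (not part of the diff) of how the new ramp weighting behaves. `times` is the per-timestep signal level in [0, 1] with shape (b, t, 1), read off the `rearrange` pattern 'b t -> b t 1', so the weight broadcasts against the unreduced MSE of shape (b, t, d): near-clean samples (times close to 1) keep a weight near 1.0, while near-pure-noise samples (times close to 0) are down-weighted toward 0.1. The tensor shapes below are illustrative assumptions, not taken from the source.

import torch
import torch.nn.functional as F

def ramp_weight(times, slope = 0.9, intercept = 0.1):
    # equation (8): linear ramp from `intercept` at times = 0 to slope + intercept at times = 1
    return slope * times + intercept

# hypothetical shapes: batch = 2, timesteps = 4, latent dim = 8
pred   = torch.randn(2, 4, 8)
target = torch.randn(2, 4, 8)

# per-timestep signal levels, already rearranged to (b, t, 1)
times = torch.rand(2, 4, 1)

losses = F.mse_loss(pred, target, reduction = 'none')   # (2, 4, 8)
weighted = losses * ramp_weight(times)                  # (2, 4, 1) broadcasts over the last dim

loss = weighted.mean()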