diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 1b2a274..9443e4b 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -1281,6 +1281,8 @@ class DynamicsModel(Module):
 
         latents = empty((batch_size, 0, *latent_shape), device = self.device)
 
+        past_context_noise = latents.clone()
+
         # maybe return rewards
 
         if return_rewards_per_frame:
@@ -1296,7 +1298,7 @@ class DynamicsModel(Module):
             for step in range(num_steps):
                 signal_levels = full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
 
-                noised_context = latents.lerp(randn_like(latents), context_signal_noise) # the paragraph after eq (8)
+                noised_context = latents.lerp(past_context_noise, context_signal_noise) # the paragraph after eq (8)
 
                 noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * n d')
 
@@ -1326,6 +1328,8 @@ class DynamicsModel(Module):
 
                 noised_latent += flow * (step_size / self.max_steps)
 
+            denoised_latent = noised_latent # fully denoised after the last flow step
+
             # take care of the rewards by predicting on the agent token embedding on the last denoising step
 
             if return_rewards_per_frame:
@@ -1336,7 +1340,11 @@ class DynamicsModel(Module):
 
             # concat the denoised latent
 
-            latents = cat((latents, noised_latent), dim = 1)
+            latents = cat((latents, denoised_latent), dim = 1)
+
+            # add new fixed context noise for temporal consistency
+
+            past_context_noise = cat((past_context_noise, randn_like(denoised_latent)), dim = 1)
 
             # restore state
 
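
For readers following along outside the repo, here is a minimal standalone sketch of the fixed context-noise scheme this diff introduces. The `denoise_step` callable and the defaults (`num_frames`, `num_steps`, `context_signal_noise`, `latent_shape`) are hypothetical stand-ins, not the repository's actual model or hyperparameters; only the variable names `latents`, `past_context_noise`, `noised_latent`, and `denoised_latent` mirror the code in the diff above.

```python
# Sketch of fixed past-context noise during autoregressive rollout.
# Assumptions: `denoise_step` is a hypothetical stand-in for the dynamics
# model's flow prediction; shapes and defaults are illustrative only.

from torch import randn, randn_like, empty, cat

def rollout_latents(
    denoise_step,             # hypothetical: (noised_context, noised_latent, step) -> flow
    num_frames = 4,
    num_steps = 8,            # denoising steps per generated frame
    context_signal_noise = 0.1,
    latent_shape = (16, 32),  # (num latent tokens, dim), illustrative
    batch_size = 2,
):
    latents = empty((batch_size, 0, *latent_shape))

    # noise for past frames is sampled once per frame and then frozen,
    # so every denoising step sees the *same* corrupted context
    past_context_noise = latents.clone()

    for _ in range(num_frames):
        noised_latent = randn((batch_size, 1, *latent_shape))

        for step in range(num_steps):
            # the fix from the diff: lerp toward the frozen noise, not a
            # fresh randn_like(latents) drawn every denoising step
            noised_context = latents.lerp(past_context_noise, context_signal_noise)

            flow = denoise_step(noised_context, noised_latent, step)
            noised_latent = noised_latent + flow / num_steps

        denoised_latent = noised_latent  # fully denoised after the last step

        latents = cat((latents, denoised_latent), dim = 1)

        # freeze a fresh noise tensor for the newly generated frame
        past_context_noise = cat((past_context_noise, randn_like(denoised_latent)), dim = 1)

    return latents

if __name__ == '__main__':
    # dummy flow model that just pushes the latent toward zero
    out = rollout_latents(lambda ctx, lat, step: -lat)
    print(out.shape)  # torch.Size([2, 4, 16, 32])
```

The point of freezing `past_context_noise` is that every denoising step, and every later frame, corrupts the history with the same noise sample instead of re-drawing `randn_like(latents)` each step. As the inline comments in the diff suggest (referencing the paragraph after eq (8)), this keeps the noised context stable across steps rather than flickering, which is what the new comment means by temporal consistency.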