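It makes more sense for the past-context noise to be fixed: sample it once per generated frame and reuse it across denoising steps, rather than resampling it at every step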

This commit is contained in:
lucidrains 2025-10-08 07:17:05 -07:00
parent 9c56ba0c9d
commit a50e360502

View File

@@ -1281,6 +1281,8 @@ class DynamicsModel(Module):
latents = empty((batch_size, 0, *latent_shape), device = self.device)
past_context_noise = latents.clone()
# maybe return rewards
if return_rewards_per_frame:
@@ -1296,7 +1298,7 @@ class DynamicsModel(Module):
for step in range(num_steps):
signal_levels = full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
noised_context = latents.lerp(randn_like(latents), context_signal_noise) # the paragraph after eq (8)
noised_context = latents.lerp(past_context_noise, context_signal_noise) # the paragraph after eq (8)
noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * n d')
@@ -1326,6 +1328,8 @@ class DynamicsModel(Module):
noised_latent += flow * (step_size / self.max_steps)
denoised_latent = noised_latent # it is now denoised
# take care of the rewards by predicting on the agent token embedding on the last denoising step
if return_rewards_per_frame:
@@ -1336,7 +1340,11 @@ class DynamicsModel(Module):
# concat the denoised latent
latents = cat((latents, noised_latent), dim = 1)
latents = cat((latents, denoised_latent), dim = 1)
# append fixed context noise for the newly generated frame, for temporal consistency
past_context_noise = cat((past_context_noise, randn_like(denoised_latent)), dim = 1)
# restore state