makes more sense for the noise to be fixed
This commit is contained in:
parent
9c56ba0c9d
commit
a50e360502
@ -1281,6 +1281,8 @@ class DynamicsModel(Module):
|
||||
|
||||
latents = empty((batch_size, 0, *latent_shape), device = self.device)
|
||||
|
||||
past_context_noise = latents.clone()
|
||||
|
||||
# maybe return rewards
|
||||
|
||||
if return_rewards_per_frame:
|
||||
@ -1296,7 +1298,7 @@ class DynamicsModel(Module):
|
||||
for step in range(num_steps):
|
||||
signal_levels = full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
|
||||
|
||||
noised_context = latents.lerp(randn_like(latents), context_signal_noise) # the paragraph after eq (8)
|
||||
noised_context = latents.lerp(past_context_noise, context_signal_noise) # the paragraph after eq (8)
|
||||
|
||||
noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * n d')
|
||||
|
||||
@ -1326,6 +1328,8 @@ class DynamicsModel(Module):
|
||||
|
||||
noised_latent += flow * (step_size / self.max_steps)
|
||||
|
||||
denoised_latent = noised_latent # it is now denoised
|
||||
|
||||
# take care of the rewards by predicting on the agent token embedding on the last denoising step
|
||||
|
||||
if return_rewards_per_frame:
|
||||
@ -1336,7 +1340,11 @@ class DynamicsModel(Module):
|
||||
|
||||
# concat the denoised latent
|
||||
|
||||
latents = cat((latents, noised_latent), dim = 1)
|
||||
latents = cat((latents, denoised_latent), dim = 1)
|
||||
|
||||
# add new fixed context noise for the temporal consistency
|
||||
|
||||
past_context_noise = cat((past_context_noise, randn_like(denoised_latent)), dim = 1)
|
||||
|
||||
# restore state
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user