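Makes more sense for the past-context noise to be fixed: sample it once per generated frame and reuse it at every denoising step, instead of redrawing it with randn_like on each step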
This commit is contained in:
parent
9c56ba0c9d
commit
a50e360502
@ -1281,6 +1281,8 @@ class DynamicsModel(Module):
|
|||||||
|
|
||||||
latents = empty((batch_size, 0, *latent_shape), device = self.device)
|
latents = empty((batch_size, 0, *latent_shape), device = self.device)
|
||||||
|
|
||||||
|
past_context_noise = latents.clone()
|
||||||
|
|
||||||
# maybe return rewards
|
# maybe return rewards
|
||||||
|
|
||||||
if return_rewards_per_frame:
|
if return_rewards_per_frame:
|
||||||
@ -1296,7 +1298,7 @@ class DynamicsModel(Module):
|
|||||||
for step in range(num_steps):
|
for step in range(num_steps):
|
||||||
signal_levels = full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
|
signal_levels = full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
|
||||||
|
|
||||||
noised_context = latents.lerp(randn_like(latents), context_signal_noise) # the paragraph after eq (8)
|
noised_context = latents.lerp(past_context_noise, context_signal_noise) # the paragraph after eq (8)
|
||||||
|
|
||||||
noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * n d')
|
noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * n d')
|
||||||
|
|
||||||
@ -1326,6 +1328,8 @@ class DynamicsModel(Module):
|
|||||||
|
|
||||||
noised_latent += flow * (step_size / self.max_steps)
|
noised_latent += flow * (step_size / self.max_steps)
|
||||||
|
|
||||||
|
denoised_latent = noised_latent # it is now denoised
|
||||||
|
|
||||||
# take care of the rewards by predicting on the agent token embedding on the last denoising step
|
# take care of the rewards by predicting on the agent token embedding on the last denoising step
|
||||||
|
|
||||||
if return_rewards_per_frame:
|
if return_rewards_per_frame:
|
||||||
@ -1336,7 +1340,11 @@ class DynamicsModel(Module):
|
|||||||
|
|
||||||
# concat the denoised latent
|
# concat the denoised latent
|
||||||
|
|
||||||
latents = cat((latents, noised_latent), dim = 1)
|
latents = cat((latents, denoised_latent), dim = 1)
|
||||||
|
|
||||||
|
# append fixed context noise for the newly denoised latent, kept constant across later steps for temporal consistency
|
||||||
|
|
||||||
|
past_context_noise = cat((past_context_noise, randn_like(denoised_latent)), dim = 1)
|
||||||
|
|
||||||
# restore state
|
# restore state
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user