This commit is contained in:
lucidrains 2025-10-07 08:09:33 -07:00
parent 1176269927
commit a8e14f4b7c
2 changed files with 2 additions and 2 deletions

View File

@ -1279,7 +1279,7 @@ class DynamicsModel(Module):
noised_latent = torch.randn((batch_size, 1, *latent_shape), device = self.device)
for step in range(num_steps):
signal_levels = torch.full((batch_size, 1), step * step_size, device = self.device)
signal_levels = torch.full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
noised_latent_with_context, pack_context_shape = pack((latents, noised_latent), 'b * n d')

View File

@ -67,7 +67,7 @@ def test_e2e(
signal_levels = step_sizes_log2 = None
if signal_and_step_passed_in:
signal_levels = torch.randint(0, 500, (2, 4))
signal_levels = torch.randint(0, 64, (2, 4))
step_sizes_log2 = torch.randint(1, 6, (2,))
if dynamics_with_video_input: