oops
This commit is contained in:
parent
1176269927
commit
a8e14f4b7c
@ -1279,7 +1279,7 @@ class DynamicsModel(Module):
|
|||||||
noised_latent = torch.randn((batch_size, 1, *latent_shape), device = self.device)
|
noised_latent = torch.randn((batch_size, 1, *latent_shape), device = self.device)
|
||||||
|
|
||||||
for step in range(num_steps):
|
for step in range(num_steps):
|
||||||
signal_levels = torch.full((batch_size, 1), step * step_size, device = self.device)
|
signal_levels = torch.full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
|
||||||
|
|
||||||
noised_latent_with_context, pack_context_shape = pack((latents, noised_latent), 'b * n d')
|
noised_latent_with_context, pack_context_shape = pack((latents, noised_latent), 'b * n d')
|
||||||
|
|
||||||
|
|||||||
@ -67,7 +67,7 @@ def test_e2e(
|
|||||||
signal_levels = step_sizes_log2 = None
|
signal_levels = step_sizes_log2 = None
|
||||||
|
|
||||||
if signal_and_step_passed_in:
|
if signal_and_step_passed_in:
|
||||||
signal_levels = torch.randint(0, 500, (2, 4))
|
signal_levels = torch.randint(0, 64, (2, 4))
|
||||||
step_sizes_log2 = torch.randint(1, 6, (2,))
|
step_sizes_log2 = torch.randint(1, 6, (2,))
|
||||||
|
|
||||||
if dynamics_with_video_input:
|
if dynamics_with_video_input:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user