From a8e14f4b7c268944d5907684d056e048739800a0 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Tue, 7 Oct 2025 08:09:33 -0700 Subject: [PATCH] oops --- dreamer4/dreamer4.py | 2 +- tests/test_dreamer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 71d8208..1fe95d4 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -1279,7 +1279,7 @@ class DynamicsModel(Module): noised_latent = torch.randn((batch_size, 1, *latent_shape), device = self.device) for step in range(num_steps): - signal_levels = torch.full((batch_size, 1), step * step_size, device = self.device) + signal_levels = torch.full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device) noised_latent_with_context, pack_context_shape = pack((latents, noised_latent), 'b * n d') diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py index 05299b3..d5e735d 100644 --- a/tests/test_dreamer.py +++ b/tests/test_dreamer.py @@ -67,7 +67,7 @@ def test_e2e( signal_levels = step_sizes_log2 = None if signal_and_step_passed_in: - signal_levels = torch.randint(0, 500, (2, 4)) + signal_levels = torch.randint(0, 64, (2, 4)) step_sizes_log2 = torch.randint(1, 6, (2,)) if dynamics_with_video_input: