From 77a40e8701b4ead35fe698693d3d353fb5a36b2c Mon Sep 17 00:00:00 2001 From: lucidrains Date: Sat, 25 Oct 2025 09:23:07 -0700 Subject: [PATCH] validate that we can generate multiple video streams for robotics use-case --- tests/test_dreamer.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py index a208b66..d8ffc1e 100644 --- a/tests/test_dreamer.py +++ b/tests/test_dreamer.py @@ -702,10 +702,11 @@ def test_proprioception( ) if num_video_views > 1: - video = torch.randn(2, num_video_views, 3, 10, 256, 256) + video_shape = (2, num_video_views, 3, 10, 256, 256) else: - video = torch.randn(2, 3, 10, 256, 256) + video_shape = (2, 3, 10, 256, 256) + video = torch.randn(*video_shape) rewards = torch.randn(2, 10) proprio = torch.randn(2, 10, 21) discrete_actions = torch.randint(0, 4, (2, 10, 1)) @@ -722,8 +723,10 @@ def test_proprioception( loss.backward() generations = dynamics.generate( - 4, + 10, batch_size = 2, + return_decoded_video = True ) assert exists(generations.proprio) + assert generations.video.shape == video_shape