validate that we can generate multiple video streams for robotics use-case

This commit is contained in:
lucidrains 2025-10-25 09:23:07 -07:00
parent 4ce82f34df
commit 77a40e8701

View File

@ -702,10 +702,11 @@ def test_proprioception(
)
if num_video_views > 1:
video = torch.randn(2, num_video_views, 3, 10, 256, 256)
video_shape = (2, num_video_views, 3, 10, 256, 256)
else:
video = torch.randn(2, 3, 10, 256, 256)
video_shape = (2, 3, 10, 256, 256)
video = torch.randn(*video_shape)
rewards = torch.randn(2, 10)
proprio = torch.randn(2, 10, 21)
discrete_actions = torch.randint(0, 4, (2, 10, 1))
@ -722,8 +723,10 @@ def test_proprioception(
loss.backward()
generations = dynamics.generate(
4,
10,
batch_size = 2,
return_decoded_video = True
)
assert exists(generations.proprio)
assert generations.video.shape == video_shape