validate that we can generate multiple video streams for robotics use-case

This commit is contained in:
lucidrains 2025-10-25 09:23:07 -07:00
parent 4ce82f34df
commit 77a40e8701

View File

@ -702,10 +702,11 @@ def test_proprioception(
) )
if num_video_views > 1: if num_video_views > 1:
video = torch.randn(2, num_video_views, 3, 10, 256, 256) video_shape = (2, num_video_views, 3, 10, 256, 256)
else: else:
video = torch.randn(2, 3, 10, 256, 256) video_shape = (2, 3, 10, 256, 256)
video = torch.randn(*video_shape)
rewards = torch.randn(2, 10) rewards = torch.randn(2, 10)
proprio = torch.randn(2, 10, 21) proprio = torch.randn(2, 10, 21)
discrete_actions = torch.randint(0, 4, (2, 10, 1)) discrete_actions = torch.randint(0, 4, (2, 10, 1))
@ -722,8 +723,10 @@ def test_proprioception(
loss.backward() loss.backward()
generations = dynamics.generate( generations = dynamics.generate(
4, 10,
batch_size = 2, batch_size = 2,
return_decoded_video = True
) )
assert exists(generations.proprio) assert exists(generations.proprio)
assert generations.video.shape == video_shape