validate that we can generate multiple video streams for robotics use-case
This commit is contained in:
parent
4ce82f34df
commit
77a40e8701
@ -702,10 +702,11 @@ def test_proprioception(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if num_video_views > 1:
|
if num_video_views > 1:
|
||||||
video = torch.randn(2, num_video_views, 3, 10, 256, 256)
|
video_shape = (2, num_video_views, 3, 10, 256, 256)
|
||||||
else:
|
else:
|
||||||
video = torch.randn(2, 3, 10, 256, 256)
|
video_shape = (2, 3, 10, 256, 256)
|
||||||
|
|
||||||
|
video = torch.randn(*video_shape)
|
||||||
rewards = torch.randn(2, 10)
|
rewards = torch.randn(2, 10)
|
||||||
proprio = torch.randn(2, 10, 21)
|
proprio = torch.randn(2, 10, 21)
|
||||||
discrete_actions = torch.randint(0, 4, (2, 10, 1))
|
discrete_actions = torch.randint(0, 4, (2, 10, 1))
|
||||||
@ -722,8 +723,10 @@ def test_proprioception(
|
|||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
generations = dynamics.generate(
|
generations = dynamics.generate(
|
||||||
4,
|
10,
|
||||||
batch_size = 2,
|
batch_size = 2,
|
||||||
|
return_decoded_video = True
|
||||||
)
|
)
|
||||||
|
|
||||||
assert exists(generations.proprio)
|
assert exists(generations.proprio)
|
||||||
|
assert generations.video.shape == video_shape
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user