diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py index a208b66..d8ffc1e 100644 --- a/tests/test_dreamer.py +++ b/tests/test_dreamer.py @@ -702,10 +702,11 @@ def test_proprioception( ) if num_video_views > 1: - video = torch.randn(2, num_video_views, 3, 10, 256, 256) + video_shape = (2, num_video_views, 3, 10, 256, 256) else: - video = torch.randn(2, 3, 10, 256, 256) + video_shape = (2, 3, 10, 256, 256) + video = torch.randn(*video_shape) rewards = torch.randn(2, 10) proprio = torch.randn(2, 10, 21) discrete_actions = torch.randint(0, 4, (2, 10, 1)) @@ -722,8 +723,10 @@ def test_proprioception( loss.backward() generations = dynamics.generate( - 4, + 10, batch_size = 2, + return_decoded_video = True ) assert exists(generations.proprio) + assert generations.video.shape == video_shape