From 4a5465eeb67fe455a0253a1c88abcc30a855bd97 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Tue, 21 Oct 2025 09:17:53 -0700 Subject: [PATCH] fix ci --- tests/test_dreamer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py index d949f24..b623759 100644 --- a/tests/test_dreamer.py +++ b/tests/test_dreamer.py @@ -447,10 +447,10 @@ def test_tokenizer_trainer(): class MockDataset(Dataset): def __len__(self): - return 4 + return 2 def __getitem__(self, idx): - return torch.randn(3, 16, 256, 256) + return torch.randn(3, 2, 64, 64) dataset = MockDataset() @@ -468,7 +468,7 @@ def test_tokenizer_trainer(): tokenizer, dataset = dataset, num_train_steps = 1, - batch_size = 2, + batch_size = 1, cpu = True ) @@ -483,7 +483,7 @@ def test_cache_generate(): max_steps = 64, num_tasks = 4, num_latent_tokens = 4, - depth = 4, + depth = 1, num_spatial_tokens = 1, pred_orig_latent = True, num_discrete_actions = 4,