This commit is contained in:
lucidrains 2025-10-21 09:17:53 -07:00
parent b34128d3d0
commit 4a5465eeb6

View File

@ -447,10 +447,10 @@ def test_tokenizer_trainer():
class MockDataset(Dataset):
def __len__(self):
return 4
return 2
def __getitem__(self, idx):
return torch.randn(3, 16, 256, 256)
return torch.randn(3, 2, 64, 64)
dataset = MockDataset()
@ -468,7 +468,7 @@ def test_tokenizer_trainer():
tokenizer,
dataset = dataset,
num_train_steps = 1,
batch_size = 2,
batch_size = 1,
cpu = True
)
@ -483,7 +483,7 @@ def test_cache_generate():
max_steps = 64,
num_tasks = 4,
num_latent_tokens = 4,
depth = 4,
depth = 1,
num_spatial_tokens = 1,
pred_orig_latent = True,
num_discrete_actions = 4,