fix ci
This commit is contained in:
parent
b34128d3d0
commit
4a5465eeb6
@ -447,10 +447,10 @@ def test_tokenizer_trainer():
|
|||||||
|
|
||||||
class MockDataset(Dataset):
|
class MockDataset(Dataset):
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return 4
|
return 2
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
return torch.randn(3, 16, 256, 256)
|
return torch.randn(3, 2, 64, 64)
|
||||||
|
|
||||||
dataset = MockDataset()
|
dataset = MockDataset()
|
||||||
|
|
||||||
@ -468,7 +468,7 @@ def test_tokenizer_trainer():
|
|||||||
tokenizer,
|
tokenizer,
|
||||||
dataset = dataset,
|
dataset = dataset,
|
||||||
num_train_steps = 1,
|
num_train_steps = 1,
|
||||||
batch_size = 2,
|
batch_size = 1,
|
||||||
cpu = True
|
cpu = True
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -483,7 +483,7 @@ def test_cache_generate():
|
|||||||
max_steps = 64,
|
max_steps = 64,
|
||||||
num_tasks = 4,
|
num_tasks = 4,
|
||||||
num_latent_tokens = 4,
|
num_latent_tokens = 4,
|
||||||
depth = 4,
|
depth = 1,
|
||||||
num_spatial_tokens = 1,
|
num_spatial_tokens = 1,
|
||||||
pred_orig_latent = True,
|
pred_orig_latent = True,
|
||||||
num_discrete_actions = 4,
|
num_discrete_actions = 4,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user