2025-10-01 09:28:25 -07:00
|
|
|
import pytest
|
2025-10-02 08:36:00 -07:00
|
|
|
# Short alias for pytest's parametrize marker, used as a decorator on tests below.
param = pytest.mark.parametrize
|
2025-10-01 09:28:25 -07:00
|
|
|
import torch
|
2025-10-01 07:18:18 -07:00
|
|
|
|
2025-10-02 08:36:00 -07:00
|
|
|
@param('pred_orig_latent', (False, True))
def test_e2e(
    pred_orig_latent
):
    """End-to-end smoke test: tokenize a random video, then run the dynamics model on its latents.

    Parametrized over ``pred_orig_latent`` (False / True), which is forwarded to
    ``DynamicsModel``. Checks that:
      * the default tokenizer call returns a single-element loss,
      * ``return_latents = True`` yields latents whose last dimension is ``dim_latent``,
      * the dynamics model returns a single-element flow loss for those latents.
    """
    from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel

    # Hyperparameters shared between model construction and the inputs/asserts
    # built from them. They are named once so the two sides cannot drift apart
    # (several of them happen to share the value 32, which hid the coupling).
    dim = 512
    dim_latent = 32
    patch_size = 32
    num_signal_levels = 500
    num_step_sizes = 32

    # NOTE(review): the video layout looks like (batch, channels, time, height, width),
    # since the per-sample tensors below are (batch, time) using dims 0 and 2 - TODO confirm
    # against VideoTokenizer.
    batch, channels, time, height, width = 2, 3, 4, 256, 256

    tokenizer = VideoTokenizer(dim, dim_latent = dim_latent, patch_size = patch_size)

    video = torch.randn(batch, channels, time, height, width)

    # The default call is expected to produce a single-element (scalar) loss.
    loss = tokenizer(video)
    assert loss.numel() == 1

    # With return_latents = True the tokenizer yields latents ending in dim_latent.
    latents = tokenizer(video, return_latents = True)
    assert latents.shape[-1] == dim_latent

    dynamics = DynamicsModel(
        dim,
        dim_latent = dim_latent,
        num_signal_levels = num_signal_levels,
        num_step_sizes = num_step_sizes,
        pred_orig_latent = pred_orig_latent
    )

    # torch.randint's upper bound is exclusive, so values fall in
    # [0, num_signal_levels) and [0, num_step_sizes) respectively.
    signal_levels = torch.randint(0, num_signal_levels, (batch, time))
    step_sizes = torch.randint(0, num_step_sizes, (batch, time))

    flow_loss = dynamics(latents, signal_levels = signal_levels, step_sizes = step_sizes)
    assert flow_loss.numel() == 1
|