2025-10-01 09:28:25 -07:00
|
|
|
import pytest
|
|
|
|
|
import torch
|
2025-10-01 07:18:18 -07:00
|
|
|
|
2025-10-01 09:28:25 -07:00
|
|
|
def test_attn():
    """Check that Attention(512) returns an output shaped like its input.

    A random (1, 1024, 512) tensor is fed through the module, and the result
    must have exactly the same shape.
    """
    from dreamer4.dreamer4 import Attention

    tokens = torch.randn(1, 1024, 512)
    attention = Attention(512)

    out = attention(tokens)
    assert out.shape == tokens.shape
|
|
|
|
|
|
|
|
|
|
def test_ff():
    """Check that SwiGLUFeedforward(512) preserves the shape of its input.

    A random (1, 1024, 512) tensor goes in; the output shape must match it.
    """
    from dreamer4.dreamer4 import SwiGLUFeedforward

    inp = torch.randn(1, 1024, 512)
    feedforward = SwiGLUFeedforward(512)

    expected_shape = inp.shape
    assert feedforward(inp).shape == expected_shape
|
2025-10-02 05:37:43 -07:00
|
|
|
|
|
|
|
|
def test_tokenizer():
    """Smoke-test both call modes of VideoTokenizer on a random tensor.

    The default call must yield a loss with exactly one element. With
    ``return_latents = True``, the output's last dimension must equal the
    configured ``dim_latent`` (32).
    """
    from dreamer4.dreamer4 import VideoTokenizer

    # NOTE(review): the positional 512 is presumably the model width — confirm
    # against the VideoTokenizer signature.
    tokenizer = VideoTokenizer(512, dim_latent = 32, patch_size = 16)

    # NOTE(review): layout presumably (batch, channels, frames, height, width);
    # confirm against VideoTokenizer's forward.
    video = torch.randn(1, 3, 16, 256, 256)

    # Default forward: a single-element loss.
    loss = tokenizer(video)
    assert loss.numel() == 1

    # Latent mode: the trailing axis is sized by dim_latent.
    latents = tokenizer(video, return_latents = True)
    assert latents.shape[-1] == 32
|