dreamer4/tests/test_dreamer.py
2025-10-02 06:11:04 -07:00

29 lines
678 B
Python

import pytest
import torch
def test_attn():
    """Attention over a (1, 1024, 512) input returns a tensor of the same shape."""
    from dreamer4.dreamer4 import Attention

    dim = 512
    tokens = torch.randn(1, 1024, dim)
    attention = Attention(dim)

    out = attention(tokens)
    assert out.shape == tokens.shape
def test_ff():
    """SwiGLUFeedforward over a (1, 1024, 512) input returns a tensor of the same shape."""
    from dreamer4.dreamer4 import SwiGLUFeedforward

    dim = 512
    tokens = torch.randn(1, 1024, dim)
    feedforward = SwiGLUFeedforward(dim)

    out = feedforward(tokens)
    assert out.shape == tokens.shape
def test_tokenizer():
    """Check the two VideoTokenizer call modes.

    The default call must return a single-element, finite loss. Calling with
    ``return_latents = True`` must return latents whose last dimension equals
    ``dim_latent``.
    """
    from dreamer4.dreamer4 import VideoTokenizer

    tokenizer = VideoTokenizer(512, dim_latent = 32, patch_size = 16)

    # presumably (batch, channels, frames, height, width) — TODO confirm layout
    x = torch.randn(1, 3, 16, 256, 256)

    loss = tokenizer(x)
    assert loss.numel() == 1
    # numel() alone would accept a NaN/inf loss; make a diverging forward fail here
    assert torch.isfinite(loss).all()

    # only the output shape is inspected, so skip building an autograd graph
    # for this comparatively large input
    with torch.no_grad():
        latents = tokenizer(x, return_latents = True)

    assert latents.shape[-1] == 32