start setting up tokenizer

lucidrains 2025-10-02 05:37:43 -07:00
parent 67519a451d
commit 31c4aa28c7
2 changed files with 80 additions and 1 deletion

@@ -4,6 +4,7 @@ import math
from functools import partial
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Module, ModuleList, Parameter, Sequential, Linear, RMSNorm, Identity
from torch import cat, stack, arange, tensor, Tensor, is_tensor
@@ -18,6 +19,9 @@ from accelerate import Accelerator
# d - feature dimension
# f - frequencies (rotary)
# p - positions (3 for spacetime in this work)
# t - time
# vc - video channels
# vh, vw - video height and width
import einx
from einops import einsum, rearrange, repeat, reduce
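
Aside: a quick sketch of this axis notation in einops terms, with made-up sizes (not part of the diff):

import torch
from einops import rearrange

video = torch.randn(2, 3, 16, 64, 64)   # (b vc t vh vw)

patches = rearrange(video, 'b c t (h p1) (w p2) -> b t h w (p1 p2 c)', p1 = 8, p2 = 8)
assert patches.shape == (2, 16, 8, 8, 192)   # 8 * 8 * 3 = 192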
@@ -331,10 +335,36 @@ class VideoTokenizer(Module):
    def __init__(
        self,
        dim,
        dim_latent,
        patch_size,
        channels = 3
    ):
        super().__init__()

        self.patch_size = patch_size

        # patch and unpatch

        dim_patch = channels * patch_size ** 2

        self.patch_to_tokens = Sequential(
            Rearrange('b c t (h p1) (w p2) -> b t h w (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            Linear(dim_patch, dim)
        )

        self.tokens_to_patch = Sequential(
            Linear(dim, dim_patch),
            Rearrange('b t h w (p1 p2 c) -> b c t (h p1) (w p2)', p1 = patch_size, p2 = patch_size),
        )
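        # shape sketch (illustrative): with channels = 3 and patch_size = 16,
        # patch_to_tokens maps (b, 3, t, H, W) -> (b, t, H / 16, W / 16, dim)
        # and tokens_to_patch inverts it, recovering (b, 3, t, H, W)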
        # encoder

        encoder_layers = []

        self.encoder_layers = ModuleList(encoder_layers)

        # latents

        self.encoded_to_latents = Sequential(
            LinearNoBias(dim, dim_latent),
            nn.Tanh(),
        )
@@ -342,6 +372,46 @@
        self.latents_to_decoder = LinearNoBias(dim_latent, dim)

        # decoder

        decoder_layers = []

        self.decoder_layers = ModuleList(decoder_layers)
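        # note: encoder_layers and decoder_layers are empty at this commit and
        # are not yet applied in forward, so the tokenizer currently reduces to
        # a linear patch autoencoder with a tanh-bounded latent bottleneck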
    def forward(
        self,
        video,                  # (b c t h w)
        return_latents = False
    ):
        patch_size = self.patch_size

        *_, height, width = video.shape
        assert divisible_by(height, patch_size) and divisible_by(width, patch_size)

        # to tokens

        tokens = self.patch_to_tokens(video)

        # latent bottleneck

        latents = self.encoded_to_latents(tokens)

        if return_latents:
            return latents

        tokens = self.latents_to_decoder(latents)

        # from tokens back to video

        recon_video = self.tokens_to_patch(tokens)

        # losses

        recon_loss = F.mse_loss(video, recon_video)

        return recon_loss
# dynamics model
class DynamicsModel(Module):
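
Aside: a minimal usage sketch of the tokenizer as it stands in this commit (not part of the diff; assumes LinearNoBias, Rearrange and divisible_by are defined elsewhere in the file):

import torch
from dreamer4.dreamer4 import VideoTokenizer

tokenizer = VideoTokenizer(dim = 512, dim_latent = 32, patch_size = 16)

video = torch.randn(1, 3, 16, 256, 256)             # (b c t h w)

loss = tokenizer(video)                             # scalar MSE reconstruction loss
loss.backward()

latents = tokenizer(video, return_latents = True)   # (1, 16, 16, 16, 32)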

@@ -15,3 +15,12 @@ def test_ff():
    ff = SwiGLUFeedforward(512)
    assert ff(x).shape == x.shape

def test_tokenizer():
    from dreamer4.dreamer4 import VideoTokenizer

    tokenizer = VideoTokenizer(512, dim_latent = 32, patch_size = 16)

    x = torch.randn(1, 3, 16, 256, 256)

    loss = tokenizer(x)
    assert loss.numel() == 1
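
A natural companion test for the latent path (hypothetical, not in this commit), using the shapes implied above:

def test_tokenizer_latents():
    from dreamer4.dreamer4 import VideoTokenizer

    tokenizer = VideoTokenizer(512, dim_latent = 32, patch_size = 16)

    x = torch.randn(1, 3, 16, 256, 256)
    latents = tokenizer(x, return_latents = True)

    assert latents.shape == (1, 16, 16, 16, 32)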