start setting up tokenizer
This commit is contained in:
parent
67519a451d
commit
31c4aa28c7
@ -4,6 +4,7 @@ import math
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Module, ModuleList, Parameter, Sequential, Linear, RMSNorm, Identity
|
||||
from torch import cat, stack, arange, tensor, Tensor, is_tensor
|
||||
@ -18,6 +19,9 @@ from accelerate import Accelerator
|
||||
# d - feature dimension
|
||||
# f - frequencies (rotary)
|
||||
# p - positions (3 for spacetime in this work)
|
||||
# t - time
|
||||
# vc - video channels
|
||||
# vh, vw - video height and width
|
||||
|
||||
import einx
|
||||
from einops import einsum, rearrange, repeat, reduce
|
||||
@ -331,10 +335,36 @@ class VideoTokenizer(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim_latent
|
||||
dim_latent,
|
||||
patch_size,
|
||||
channels = 3
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.patch_size = patch_size
|
||||
|
||||
# patch and unpatch
|
||||
|
||||
dim_patch = channels * patch_size ** 2
|
||||
|
||||
self.patch_to_tokens = Sequential(
|
||||
Rearrange('b c t (h p1) (w p2) -> b t h w (p1 p2 c)', p1 = patch_size, p2 = patch_size),
|
||||
Linear(dim_patch, dim)
|
||||
)
|
||||
|
||||
self.tokens_to_patch = Sequential(
|
||||
Linear(dim, dim_patch),
|
||||
Rearrange('b t h w (p1 p2 c) -> b c t (h p1) (w p2)', p1 = patch_size, p2 = patch_size),
|
||||
)
|
||||
|
||||
# encoder
|
||||
|
||||
encoder_layers = []
|
||||
|
||||
self.encoder_layers = ModuleList(encoder_layers)
|
||||
|
||||
# latents
|
||||
|
||||
self.encoded_to_latents = Sequential(
|
||||
LinearNoBias(dim, dim_latent),
|
||||
nn.Tanh(),
|
||||
@ -342,6 +372,46 @@ class VideoTokenizer(Module):
|
||||
|
||||
self.latents_to_decoder = LinearNoBias(dim_latent, dim)
|
||||
|
||||
# decoder
|
||||
|
||||
decoder_layers = []
|
||||
|
||||
self.decoder_layers = ModuleList(decoder_layers)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
video, # (b c t h w)
|
||||
return_latents = False
|
||||
):
|
||||
patch_size = self.patch_size
|
||||
|
||||
*_, height, width = video.shape
|
||||
|
||||
assert divisible_by(height, patch_size) and divisible_by(width, patch_size)
|
||||
|
||||
# to tokens
|
||||
|
||||
tokens = self.patch_to_tokens(video)
|
||||
|
||||
# latent bottleneck
|
||||
|
||||
latents = self.encoded_to_latents(tokens)
|
||||
|
||||
if return_latents:
|
||||
return latents
|
||||
|
||||
tokens = self.latents_to_decoder(latents)
|
||||
|
||||
# from tokens back to video
|
||||
|
||||
recon_video = self.tokens_to_patch(tokens)
|
||||
|
||||
# losses
|
||||
|
||||
recon_loss = F.mse_loss(video, recon_video)
|
||||
|
||||
return recon_loss
|
||||
|
||||
# dynamics model
|
||||
|
||||
class DynamicsModel(Module):
|
||||
|
||||
@ -15,3 +15,12 @@ def test_ff():
|
||||
ff = SwiGLUFeedforward(512)
|
||||
|
||||
assert ff(x).shape == x.shape
|
||||
|
||||
def test_tokenizer():
|
||||
from dreamer4.dreamer4 import VideoTokenizer
|
||||
|
||||
tokenizer = VideoTokenizer(512, dim_latent = 32, patch_size = 16)
|
||||
x = torch.randn(1, 3, 16, 256, 256)
|
||||
|
||||
loss = tokenizer(x)
|
||||
assert loss.numel() == 1
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user