start setting up tokenizer
This commit is contained in:
parent
67519a451d
commit
31c4aa28c7
@ -4,6 +4,7 @@ import math
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from torch import nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.nn import Module, ModuleList, Parameter, Sequential, Linear, RMSNorm, Identity
|
from torch.nn import Module, ModuleList, Parameter, Sequential, Linear, RMSNorm, Identity
|
||||||
from torch import cat, stack, arange, tensor, Tensor, is_tensor
|
from torch import cat, stack, arange, tensor, Tensor, is_tensor
|
||||||
@ -18,6 +19,9 @@ from accelerate import Accelerator
|
|||||||
# d - feature dimension
|
# d - feature dimension
|
||||||
# f - frequencies (rotary)
|
# f - frequencies (rotary)
|
||||||
# p - positions (3 for spacetime in this work)
|
# p - positions (3 for spacetime in this work)
|
||||||
|
# t - time
|
||||||
|
# vc - video channels
|
||||||
|
# vh, vw - video height and width
|
||||||
|
|
||||||
import einx
|
import einx
|
||||||
from einops import einsum, rearrange, repeat, reduce
|
from einops import einsum, rearrange, repeat, reduce
|
||||||
@ -331,10 +335,36 @@ class VideoTokenizer(Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dim,
|
dim,
|
||||||
dim_latent
|
dim_latent,
|
||||||
|
patch_size,
|
||||||
|
channels = 3
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
self.patch_size = patch_size
|
||||||
|
|
||||||
|
# patch and unpatch
|
||||||
|
|
||||||
|
dim_patch = channels * patch_size ** 2
|
||||||
|
|
||||||
|
self.patch_to_tokens = Sequential(
|
||||||
|
Rearrange('b c t (h p1) (w p2) -> b t h w (p1 p2 c)', p1 = patch_size, p2 = patch_size),
|
||||||
|
Linear(dim_patch, dim)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.tokens_to_patch = Sequential(
|
||||||
|
Linear(dim, dim_patch),
|
||||||
|
Rearrange('b t h w (p1 p2 c) -> b c t (h p1) (w p2)', p1 = patch_size, p2 = patch_size),
|
||||||
|
)
|
||||||
|
|
||||||
|
# encoder
|
||||||
|
|
||||||
|
encoder_layers = []
|
||||||
|
|
||||||
|
self.encoder_layers = ModuleList(encoder_layers)
|
||||||
|
|
||||||
|
# latents
|
||||||
|
|
||||||
self.encoded_to_latents = Sequential(
|
self.encoded_to_latents = Sequential(
|
||||||
LinearNoBias(dim, dim_latent),
|
LinearNoBias(dim, dim_latent),
|
||||||
nn.Tanh(),
|
nn.Tanh(),
|
||||||
@ -342,6 +372,46 @@ class VideoTokenizer(Module):
|
|||||||
|
|
||||||
self.latents_to_decoder = LinearNoBias(dim_latent, dim)
|
self.latents_to_decoder = LinearNoBias(dim_latent, dim)
|
||||||
|
|
||||||
|
# decoder
|
||||||
|
|
||||||
|
decoder_layers = []
|
||||||
|
|
||||||
|
self.decoder_layers = ModuleList(decoder_layers)
|
||||||
|
|
||||||
|
def forward(
    self,
    video, # (b c t h w)
    return_latents = False
):
    """
    Autoencode a video through the patch projections and the latent bottleneck.

    video          - input tensor, laid out as (b c t h w)
    return_latents - when truthy, stop after the bottleneck and return the latents

    Returns the output of `self.encoded_to_latents` when `return_latents` is set,
    otherwise the mean squared error between `video` and its reconstruction.
    """

    patch = self.patch_size

    # only the trailing two (spatial) sizes matter for the patch check
    *_, frame_height, frame_width = video.shape

    # both spatial sides are checked against the patch size before patching
    assert all(divisible_by(side, patch) for side in (frame_height, frame_width))

    # video -> patch tokens -> bottleneck latents

    latents = self.encoded_to_latents(self.patch_to_tokens(video))

    if return_latents:
        return latents

    # latents -> decoder tokens -> reconstructed video

    recon_video = self.tokens_to_patch(self.latents_to_decoder(latents))

    # reconstruction loss

    return F.mse_loss(video, recon_video)
|
||||||
|
|
||||||
# dynamics model
|
# dynamics model
|
||||||
|
|
||||||
class DynamicsModel(Module):
|
class DynamicsModel(Module):
|
||||||
|
|||||||
@ -15,3 +15,12 @@ def test_ff():
|
|||||||
ff = SwiGLUFeedforward(512)
|
ff = SwiGLUFeedforward(512)
|
||||||
|
|
||||||
assert ff(x).shape == x.shape
|
assert ff(x).shape == x.shape
|
||||||
|
|
||||||
|
def test_tokenizer():
    """Smoke test: one VideoTokenizer forward pass on random video yields a single-element loss."""
    from dreamer4.dreamer4 import VideoTokenizer

    # built before the input is sampled, same order as before
    video_tokenizer = VideoTokenizer(512, dim_latent = 32, patch_size = 16)

    # (b c t h w) - 256 x 256 frames, evenly divisible by the 16 pixel patches
    video = torch.randn(1, 3, 16, 256, 256)

    recon_loss = video_tokenizer(video)

    assert recon_loss.numel() == 1
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user