diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index f843f1a..300c764 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -4,6 +4,7 @@ import math
 from functools import partial
 
 import torch
+from torch import nn
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Parameter, Sequential, Linear, RMSNorm, Identity
 from torch import cat, stack, arange, tensor, Tensor, is_tensor
@@ -18,6 +19,9 @@ from accelerate import Accelerator
 # d - feature dimension
 # f - frequencies (rotary)
 # p - positions (3 for spacetime in this work)
+# t - time
+# vc - video channels
+# vh, vw - video height and width
 
 import einx
 from einops import einsum, rearrange, repeat, reduce
@@ -331,10 +335,36 @@ class VideoTokenizer(Module):
     def __init__(
         self,
         dim,
-        dim_latent
+        dim_latent,
+        patch_size,
+        channels = 3
     ):
         super().__init__()
 
+        self.patch_size = patch_size
+
+        # patch and unpatch
+
+        dim_patch = channels * patch_size ** 2
+
+        self.patch_to_tokens = Sequential(
+            Rearrange('b c t (h p1) (w p2) -> b t h w (p1 p2 c)', p1 = patch_size, p2 = patch_size),
+            Linear(dim_patch, dim)
+        )
+
+        self.tokens_to_patch = Sequential(
+            Linear(dim, dim_patch),
+            Rearrange('b t h w (p1 p2 c) -> b c t (h p1) (w p2)', p1 = patch_size, p2 = patch_size),
+        )
+
+        # encoder
+
+        encoder_layers = []
+
+        self.encoder_layers = ModuleList(encoder_layers)
+
+        # latents
+
         self.encoded_to_latents = Sequential(
             LinearNoBias(dim, dim_latent),
             nn.Tanh(),
@@ -342,6 +372,46 @@
 
         self.latents_to_decoder = LinearNoBias(dim_latent, dim)
 
+        # decoder
+
+        decoder_layers = []
+
+        self.decoder_layers = ModuleList(decoder_layers)
+
+    def forward(
+        self,
+        video, # (b c t h w)
+        return_latents = False
+    ):
+        patch_size = self.patch_size
+
+        *_, height, width = video.shape
+
+        assert divisible_by(height, patch_size) and divisible_by(width, patch_size)
+
+        # to tokens
+
+        tokens = self.patch_to_tokens(video)
+
+        # latent bottleneck
+
+        latents = self.encoded_to_latents(tokens)
+
+        if return_latents:
+            return latents
+
+        tokens = self.latents_to_decoder(latents)
+
+        # from tokens back to video
+
+        recon_video = self.tokens_to_patch(tokens)
+
+        # losses
+
+        recon_loss = F.mse_loss(video, recon_video)
+
+        return recon_loss
+
 # dynamics model
 
 class DynamicsModel(Module):
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index 701f452..78f5102 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -15,3 +15,12 @@ def test_ff():
     ff = SwiGLUFeedforward(512)
 
     assert ff(x).shape == x.shape
+
+def test_tokenizer():
+    from dreamer4.dreamer4 import VideoTokenizer
+
+    tokenizer = VideoTokenizer(512, dim_latent = 32, patch_size = 16)
+    x = torch.randn(1, 3, 16, 256, 256)
+
+    loss = tokenizer(x)
+    assert loss.numel() == 1
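
A minimal usage sketch for the new VideoTokenizer, mirroring the test above. It assumes only the interface this diff introduces: videos shaped (batch, channels, time, height, width) with height and width divisible by patch_size, a scalar MSE reconstruction loss from forward, and the tanh-bounded bottleneck latents via return_latents = True. The batch size of 2 below is illustrative.

import torch
from dreamer4.dreamer4 import VideoTokenizer

tokenizer = VideoTokenizer(512, dim_latent = 32, patch_size = 16)

video = torch.randn(2, 3, 16, 256, 256) # (b vc t vh vw) per the notation glossary

# training: forward returns the scalar reconstruction loss

loss = tokenizer(video)
loss.backward()

# inference: latents for the downstream dynamics model
# 256 / 16 = 16 patches per side, so latents come out as (2, 16, 16, 16, 32), i.e. (b t h w d)

latents = tokenizer(video, return_latents = True)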