From 0285bba82148d0cc555458c73505e92390fa988e Mon Sep 17 00:00:00 2001 From: lucidrains Date: Thu, 2 Oct 2025 06:11:04 -0700 Subject: [PATCH] flesh out tokenizer even more --- dreamer4/dreamer4.py | 79 +++++++++++++++++++++++++++++++++++++++++-- tests/test_dreamer.py | 3 ++ 2 files changed, 80 insertions(+), 2 deletions(-) diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 300c764..7c46ff6 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -24,7 +24,7 @@ from accelerate import Accelerator # vh, vw - video height and width import einx -from einops import einsum, rearrange, repeat, reduce +from einops import einsum, rearrange, repeat, reduce, pack, unpack from einops.layers.torch import Rearrange # flex attention - but will make sure it works if it is not available @@ -51,9 +51,21 @@ def exists(v): def default(v, d): return v if exists(v) else d +def first(arr): + return arr[0] + def divisible_by(num, den): return (num % den) == 0 +def pack_one(t, pattern): + packed, packed_shape = pack([t], pattern) + + def inverse(out, inv_pattern = None): + inv_pattern = default(inv_pattern, pattern) + return first(unpack(out, packed_shape, inv_pattern)) + + return packed, inverse + def l2norm(t): return F.normalize(t, dim = -1, p = 2) @@ -337,12 +349,21 @@ class VideoTokenizer(Module): dim, dim_latent, patch_size, + encoder_depth = 4, + decoder_depth = 4, + attn_kwargs: dict = dict(), + ff_kwargs: dict = dict(), channels = 3 ): super().__init__() self.patch_size = patch_size + # special tokens + + self.latent_token = Parameter(torch.randn(dim) * 1e-2) + self.mask_token = Parameter(torch.randn(dim) * 1e-2) + # patch and unpatch dim_patch = channels * patch_size ** 2 @@ -361,7 +382,14 @@ class VideoTokenizer(Module): encoder_layers = [] + for _ in range(encoder_depth): + encoder_layers.append(ModuleList([ + Attention(dim = dim, **attn_kwargs), + SwiGLUFeedforward(dim = dim, **ff_kwargs) + ])) + self.encoder_layers = ModuleList(encoder_layers) + self.encoder_norm = RMSNorm(dim) # latents @@ -376,7 +404,14 @@ class VideoTokenizer(Module): decoder_layers = [] + for _ in range(decoder_depth): + decoder_layers.append(ModuleList([ + Attention(dim = dim, **attn_kwargs), + SwiGLUFeedforward(dim = dim, **ff_kwargs) + ])) + self.decoder_layers = ModuleList(decoder_layers) + self.decoder_norm = RMSNorm(dim) def forward( self, @@ -393,6 +428,26 @@ class VideoTokenizer(Module): tokens = self.patch_to_tokens(video) + tokens, inverse_pack_space = pack_one(tokens, 'b t * d') + + # add the latent + + latents = repeat(self.latent_token, 'd -> b t 1 d', b = tokens.shape[0], t = tokens.shape[1]) + + tokens = cat((tokens, latents), dim = -2) + + # pack time + + tokens, inverse_pack_time = pack_one(tokens, 'b * d') + + # encoder + + for attn, ff in self.encoder_layers: + tokens = attn(tokens) + tokens + tokens = ff(tokens) + tokens + + tokens = self.encoder_norm(tokens) + # latent bottleneck latents = self.encoded_to_latents(tokens) @@ -402,7 +457,27 @@ class VideoTokenizer(Module): tokens = self.latents_to_decoder(latents) - # from tokens back to video + # decoder + + for attn, ff in self.decoder_layers: + tokens = attn(tokens) + tokens + tokens = ff(tokens) + tokens + + tokens = self.decoder_norm(tokens) + + # unpack time + + tokens = inverse_pack_time(tokens) + + # excise latents + + tokens = tokens[..., :-1, :] + + # unpack space + + tokens = inverse_pack_space(tokens) + + # project back to patches recon_video = self.tokens_to_patch(tokens) diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py index 78f5102..9085b48 100644 --- a/tests/test_dreamer.py +++ b/tests/test_dreamer.py @@ -24,3 +24,6 @@ def test_tokenizer(): loss = tokenizer(x) assert loss.numel() == 1 + + latents = tokenizer(x, return_latents = True) + assert latents.shape[-1] == 32 \ No newline at end of file