flesh out tokenizer even more

parent 31c4aa28c7
commit 0285bba821
@@ -24,7 +24,7 @@ from accelerate import Accelerator
 # vh, vw - video height and width

 import einx
-from einops import einsum, rearrange, repeat, reduce
+from einops import einsum, rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange

 # flex attention - but will make sure it works if it is not available
@@ -51,9 +51,21 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d

+def first(arr):
+    return arr[0]
+
 def divisible_by(num, den):
     return (num % den) == 0

+def pack_one(t, pattern):
+    packed, packed_shape = pack([t], pattern)
+
+    def inverse(out, inv_pattern = None):
+        inv_pattern = default(inv_pattern, pattern)
+        return first(unpack(out, packed_shape, inv_pattern))
+
+    return packed, inverse
+
 def l2norm(t):
     return F.normalize(t, dim = -1, p = 2)

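A quick note on the new pack_one helper: it wraps einops pack / unpack for the single-tensor case and hands back the packed tensor together with an inverse closure. An illustrative usage, with made-up shapes (not part of the commit):

import torch

tokens = torch.randn(2, 4, 8, 8, 512)            # (batch, time, height, width, dim)

packed, inverse = pack_one(tokens, 'b t * d')    # spatial dims folded into one axis
assert packed.shape == (2, 4, 64, 512)

restored = inverse(packed)                       # unpacks with the original pattern
assert restored.shape == (2, 4, 8, 8, 512)

The optional inv_pattern argument lets the inverse run under a different pattern than the one used to pack, e.g. after the tokens were projected to another feature dimension in between.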
@@ -337,12 +349,21 @@ class VideoTokenizer(Module):
         dim,
         dim_latent,
         patch_size,
+        encoder_depth = 4,
+        decoder_depth = 4,
+        attn_kwargs: dict = dict(),
+        ff_kwargs: dict = dict(),
         channels = 3
     ):
         super().__init__()

         self.patch_size = patch_size

+        # special tokens
+
+        self.latent_token = Parameter(torch.randn(dim) * 1e-2)
+        self.mask_token = Parameter(torch.randn(dim) * 1e-2)
+
         # patch and unpatch

         dim_patch = channels * patch_size ** 2
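With the new constructor arguments, instantiation would look roughly like this. The heads key is an assumption about the Attention signature, which this diff does not show; attn_kwargs and ff_kwargs are simply forwarded to every layer:

tokenizer = VideoTokenizer(
    dim = 512,
    dim_latent = 32,
    patch_size = 16,
    encoder_depth = 4,
    decoder_depth = 4,
    attn_kwargs = dict(heads = 8),   # forwarded verbatim to each Attention layer (key assumed)
    ff_kwargs = dict()               # forwarded verbatim to each SwiGLUFeedforward
)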
@@ -361,7 +382,14 @@ class VideoTokenizer(Module):

         encoder_layers = []

+        for _ in range(encoder_depth):
+            encoder_layers.append(ModuleList([
+                Attention(dim = dim, **attn_kwargs),
+                SwiGLUFeedforward(dim = dim, **ff_kwargs)
+            ]))
+
         self.encoder_layers = ModuleList(encoder_layers)
+        self.encoder_norm = RMSNorm(dim)

         # latents

@@ -376,7 +404,14 @@ class VideoTokenizer(Module):

         decoder_layers = []

+        for _ in range(decoder_depth):
+            decoder_layers.append(ModuleList([
+                Attention(dim = dim, **attn_kwargs),
+                SwiGLUFeedforward(dim = dim, **ff_kwargs)
+            ]))
+
         self.decoder_layers = ModuleList(decoder_layers)
+        self.decoder_norm = RMSNorm(dim)

     def forward(
         self,
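The encoder and decoder stacks are deliberately symmetric: depth-many [Attention, SwiGLUFeedforward] pairs applied with plain residuals, then a single RMSNorm at the end. Factored out as a hypothetical helper (not in the commit), each stack application in forward reduces to:

def run_stack(tokens, layers, norm):
    # shared shape of the encoder and decoder loops in forward
    for attn, ff in layers:
        tokens = attn(tokens) + tokens   # residual attention
        tokens = ff(tokens) + tokens     # residual feedforward
    return norm(tokens)                  # one final norm for the whole stack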
@@ -393,6 +428,26 @@ class VideoTokenizer(Module):

         tokens = self.patch_to_tokens(video)

+        tokens, inverse_pack_space = pack_one(tokens, 'b t * d')
+
+        # add the latent
+
+        latents = repeat(self.latent_token, 'd -> b t 1 d', b = tokens.shape[0], t = tokens.shape[1])
+
+        tokens = cat((tokens, latents), dim = -2)
+
+        # pack time
+
+        tokens, inverse_pack_time = pack_one(tokens, 'b * d')
+
+        # encoder
+
+        for attn, ff in self.encoder_layers:
+            tokens = attn(tokens) + tokens
+            tokens = ff(tokens) + tokens
+
+        tokens = self.encoder_norm(tokens)
+
         # latent bottleneck

         latents = self.encoded_to_latents(tokens)
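To make the packing steps concrete, here is the encode-side shape flow written standalone with einops pack instead of the pack_one wrapper; all sizes are illustrative:

import torch
from einops import pack, repeat

b, t, vh, vw, d = 2, 4, 8, 8, 512
tokens = torch.randn(b, t, vh, vw, d)              # stands in for the output of patch_to_tokens

tokens, ps_space = pack([tokens], 'b t * d')       # (2, 4, 64, 512), space flattened

latents = repeat(torch.randn(d), 'd -> b t 1 d', b = b, t = t)  # stands in for self.latent_token
tokens = torch.cat((tokens, latents), dim = -2)    # (2, 4, 65, 512), one latent slot per frame

tokens, ps_time = pack([tokens], 'b * d')          # (2, 260, 512), time folded in as well
assert tokens.shape == (b, t * 65, d)

So the encoder attends across every frame and patch at once, with one learned latent slot per frame riding along.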
@@ -402,7 +457,27 @@ class VideoTokenizer(Module):

         tokens = self.latents_to_decoder(latents)

-        # from tokens back to video
+        # decoder
+
+        for attn, ff in self.decoder_layers:
+            tokens = attn(tokens) + tokens
+            tokens = ff(tokens) + tokens
+
+        tokens = self.decoder_norm(tokens)
+
+        # unpack time
+
+        tokens = inverse_pack_time(tokens)
+
+        # excise latents
+
+        tokens = tokens[..., :-1, :]
+
+        # unpack space
+
+        tokens = inverse_pack_space(tokens)
+
+        # project back to patches

         recon_video = self.tokens_to_patch(tokens)

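The decode side runs the two inverses in reverse order, excising the per-frame latent slot in between. Continuing the standalone sketch above:

from einops import unpack

tokens = unpack(tokens, ps_time, 'b * d')[0]       # (2, 4, 65, 512), time restored
tokens = tokens[..., :-1, :]                       # (2, 4, 64, 512), latent slot dropped
tokens = unpack(tokens, ps_space, 'b t * d')[0]    # (2, 4, 8, 8, 512), space restored
assert tokens.shape == (b, t, vh, vw, d)

The latent was concatenated last along the token axis, which is why tokens[..., :-1, :] is the correct excision.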
@@ -24,3 +24,6 @@ def test_tokenizer():

    loss = tokenizer(x)
    assert loss.numel() == 1
+
+    latents = tokenizer(x, return_latents = True)
+    assert latents.shape[-1] == 32
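For completeness, the round trip the test now exercises, as a sketch with the tokenizer constructed as in the earlier example; the (batch, channels, time, height, width) input layout and dim_latent = 32 are inferred from the test rather than stated in this hunk:

x = torch.randn(1, 3, 4, 64, 64)                   # assumed (b, c, t, h, w) layout

loss = tokenizer(x)                                # default path: scalar reconstruction loss
loss.backward()

latents = tokenizer(x, return_latents = True)      # presumably returns the bottleneck instead
assert latents.shape[-1] == 32                     # dim_latent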