reorganize tokenizer to generate video from the dynamics model

lucidrains 2025-10-06 11:37:45 -07:00
parent 7180a8cf43
commit 83ba9a285a
2 changed files with 122 additions and 59 deletions
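In rough terms, this reorganization factors latent decoding out of `forward` into a standalone `decode` (with `get_rotary_pos_emb` extracted so the dynamics model can reuse it), and stubs out `DynamicsModel.generate`. A minimal usage sketch of the resulting API; only `dim_latent = 16` is taken from the tests below, the other constructor values are illustrative:

import torch

tokenizer = VideoTokenizer(
    dim = 512,          # illustrative
    dim_latent = 16,    # matches the latent dim check in the tests
    patch_size = 16,    # illustrative
    image_height = 256, # now optional - may instead be passed to decode()
    image_width = 256,
)

video = torch.randn(1, 3, 4, 256, 256)       # (b c t h w)

latents = tokenizer.tokenize(video)          # (b t n d), eval mode, no grad
recon = tokenizer.decode(latents, 256, 256)  # (b c t h w)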


@@ -20,8 +20,6 @@ from x_mlps_pytorch.ensemble import Ensemble
from assoc_scan import AssocScan
from accelerate import Accelerator
# ein related
# b - batch
@@ -710,6 +708,8 @@ class VideoTokenizer(Module):
dim,
dim_latent,
patch_size,
image_height = None,
image_width = None,
num_latent_tokens = 4,
encoder_depth = 4,
decoder_depth = 4,
@@ -795,6 +795,9 @@ class VideoTokenizer(Module):
# decoder
self.image_height = image_height
self.image_width = image_width
# parameterize the decoder positional embeddings for MAE style training so it can be resolution agnostic
self.to_decoder_pos_emb = create_mlp(
@@ -825,6 +828,10 @@ class VideoTokenizer(Module):
if self.has_lpips_loss:
self.lpips = LPIPSLoss(lpips_loss_network)
@property
def device(self):
return self.zero.device
@torch.no_grad()
def tokenize(
self,
@@ -833,6 +840,102 @@ class VideoTokenizer(Module):
self.eval()
return self.forward(video, return_latents = True)
def get_rotary_pos_emb(
self,
time,
num_patch_height,
num_patch_width
):
device = self.device
positions = stack(torch.meshgrid(
arange(time, device = device),
arange(num_patch_height, device = device),
arange(num_patch_width, device = device)
), dim = -1)
positions = rearrange(positions, 't h w p -> t (h w) p')
# give the latents an out of bounds position and assume the network will figure it out
positions = pad_at_dim(positions, (0, self.num_latent_tokens), dim = -2, value = -1) # todo - make this value configurable, and ultimately craft own flash attention function where certain positions can be unrotated
positions = rearrange(positions, 't hw p -> (t hw) p')
return self.spacetime_rotary(positions)
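For reference, a self-contained sketch of the coordinate tensor this method builds, with toy dimensions and `pad_at_dim` swapped for the equivalent `torch.nn.functional.pad` call:

import torch
import torch.nn.functional as F
from einops import rearrange

time, height, width, num_latents = 2, 3, 3, 4

# (t h w 3) grid of integer (time, height, width) coordinates
positions = torch.stack(torch.meshgrid(
    torch.arange(time),
    torch.arange(height),
    torch.arange(width),
    indexing = 'ij'
), dim = -1)

positions = rearrange(positions, 't h w p -> t (h w) p')          # (2 9 3)

# latent tokens receive the out-of-bounds position -1, per the comment above
positions = F.pad(positions, (0, 0, 0, num_latents), value = -1)  # (2 13 3)

positions = rearrange(positions, 't hw p -> (t hw) p')            # (26 3)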
def decode(
self,
latents, # (b t n d)
height = None,
width = None,
rotary_pos_emb = None
): # (b c t h w)
height = default(height, self.image_height)
width = default(width, self.image_width)
assert exists(height) and exists(width), f'image height and width need to be passed in when decoding latents'
batch, time, device = *latents.shape[:2], latents.device
use_flex = latents.is_cuda and exists(flex_attention)
num_patch_height = height // self.patch_size
num_patch_width = width // self.patch_size
if not exists(rotary_pos_emb):
rotary_pos_emb = self.get_rotary_pos_emb(time, num_patch_height, num_patch_width)
# latents to tokens
latent_tokens = self.latents_to_decoder(latents)
# generate decoder positional embedding and concat the latent token
spatial_pos_height = torch.linspace(-1., 1., num_patch_height, device = device)
spatial_pos_width = torch.linspace(-1., 1., num_patch_width, device = device)
space_height_width_coor = stack(torch.meshgrid(spatial_pos_height, spatial_pos_width, indexing = 'ij'), dim = -1)
decoder_pos_emb = self.to_decoder_pos_emb(space_height_width_coor)
decoder_pos_emb = repeat(decoder_pos_emb, '... -> b t ...', b = batch, t = time)
tokens, packed_latent_shape = pack((decoder_pos_emb, latent_tokens), 'b t * d')
# pack time
tokens, inverse_pack_time = pack_one(tokens, 'b * d')
seq_len = tokens.shape[-2]
# decoder attend
decoder_attend_fn = get_attend_fn(use_flex, seq_len, seq_len, special_attend_only_itself = True)
# decoder attention
for attn, ff in self.decoder_layers:
tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = decoder_attend_fn) + tokens
tokens = ff(tokens) + tokens
tokens = self.decoder_norm(tokens)
# unpack time
tokens = inverse_pack_time(tokens)
# unpack latents
tokens, latent_tokens = unpack(tokens, packed_latent_shape, 'b t * d')
# project back to patches
recon_video = self.tokens_to_patch(tokens)
return recon_video
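The decoder queries are pure positions: `to_decoder_pos_emb` is an MLP over coordinates that `linspace` normalizes to [-1, 1], so the same weights can embed any patch grid, which is what makes `decode` resolution-agnostic per the MAE-style comment in `__init__`. A toy illustration of the normalized grid, independent of the model:

import torch

def normalized_patch_coords(num_patch_height, num_patch_width):
    hs = torch.linspace(-1., 1., num_patch_height)
    ws = torch.linspace(-1., 1., num_patch_width)
    return torch.stack(torch.meshgrid(hs, ws, indexing = 'ij'), dim = -1)  # (h w 2)

# both grids live in [-1, 1]^2 - only the sampling density changes with resolution
coords_coarse = normalized_patch_coords(8, 8)   # (8 8 2)
coords_fine = normalized_patch_coords(16, 16)   # (16 16 2)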
def forward(
self,
video, # (b c t h w)
@@ -855,21 +958,7 @@ class VideoTokenizer(Module):
# rotary positions
positions = stack(torch.meshgrid(
arange(time, device = device),
arange(num_patch_height, device = device),
arange(num_patch_width, device = device)
), dim = -1)
positions = rearrange(positions, 't h w p -> t (h w) p')
# give the latents an out of bounds position and assume the network will figure it out
positions = pad_at_dim(positions, (0, self.num_latent_tokens), dim = -2, value = -1) # todo - make this value configurable, and ultimately craft own flash attention function where certain positions can be unrotated
positions = rearrange(positions, 't hw p -> (t hw) p')
rotary_pos_emb = self.spacetime_rotary(positions)
rotary_pos_emb = self.get_rotary_pos_emb(time, num_patch_height, num_patch_width)
# masking
@@ -941,48 +1030,7 @@ class VideoTokenizer(Module):
if return_latents:
return latents
latent_tokens = self.latents_to_decoder(latents)
# generate decoder positional embedding and concat the latent token
spatial_pos_height = torch.linspace(-1., 1., num_patch_height, device = device)
spatial_pos_width = torch.linspace(-1., 1., num_patch_width, device = device)
space_height_width_coor = stack(torch.meshgrid(spatial_pos_height, spatial_pos_width, indexing = 'ij'), dim = -1)
decoder_pos_emb = self.to_decoder_pos_emb(space_height_width_coor)
decoder_pos_emb = repeat(decoder_pos_emb, '... -> b t ...', b = batch, t = time)
tokens, packed_latent_shape = pack((decoder_pos_emb, latent_tokens), 'b t * d')
# pack time
tokens, inverse_pack_time = pack_one(tokens, 'b * d')
# decoder attend
decoder_attend_fn = get_attend_fn(use_flex, seq_len, seq_len, special_attend_only_itself = True)
# decoder attention
for attn, ff in self.decoder_layers:
tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = decoder_attend_fn) + tokens
tokens = ff(tokens) + tokens
tokens = self.decoder_norm(tokens)
# unpack time
tokens = inverse_pack_time(tokens)
# unpack latents
tokens, latent_tokens = unpack(tokens, packed_latent_shape, 'b t * d')
# project back to patches
recon_video = self.tokens_to_patch(tokens)
recon_video = self.decode(latents, height = height, width = width, rotary_pos_emb = rotary_pos_emb)
# losses
@@ -1180,6 +1228,18 @@ class DynamicsModel(Module):
return list(set(params) - set(self.video_tokenizer.parameters()))
def generate(
self,
num_frames,
num_steps = 4,
image_height = None,
image_width = None
): # (b t n d) | (b c t h w)
assert log2(num_steps).is_integer(), 'number of steps must be a power of 2'
raise NotImplementedError
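The power-of-two check relies on `log2` returning an exact float for powers of two; the natural `log` would not work here:

from math import log, log2

log(4)                # 1.3862... - natural log, not an integer even for powers of 2
log2(4).is_integer()  # True - the default num_steps = 4 passes
log2(6).is_integer()  # False - rejected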
def forward(
self,
*,


@@ -40,6 +40,9 @@ def test_e2e(
latents = tokenizer(video, return_latents = True)
assert latents.shape[-1] == 16
recon = tokenizer.decode(latents, 256, 256)
assert recon.shape == video.shape
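Given the resolution-agnostic decoder, decoding the same latents at a size other than the tokenized one should also work - a speculative extension of the test, not asserted by the source:

recon_small = tokenizer.decode(latents, 128, 128)  # speculative - assumes the decoder generalizes across resolutions
assert recon_small.shape == (*video.shape[:3], 128, 128)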
query_heads, heads = (16, 4) if grouped_query_attn else (8, 8)
dynamics = DynamicsModel(