reorganize tokenizer to generate video from the dynamics model

parent 7180a8cf43
commit 83ba9a285a
@@ -20,8 +20,6 @@ from x_mlps_pytorch.ensemble import Ensemble

from assoc_scan import AssocScan

from accelerate import Accelerator

# ein related

# b - batch

@@ -710,6 +708,8 @@ class VideoTokenizer(Module):
        dim,
        dim_latent,
        patch_size,
+        image_height = None,
+        image_width = None,
        num_latent_tokens = 4,
        encoder_depth = 4,
        decoder_depth = 4,

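The two new constructor arguments let a default output resolution be baked into the tokenizer, so later calls to decode can omit height and width. A minimal construction sketch, assuming the arguments visible in this hunk are enough to build the module; every value below is a placeholder except dim_latent = 16, which matches the test at the bottom of this commit:

    # illustrative values only
    tokenizer = VideoTokenizer(
        dim = 512,
        dim_latent = 16,
        patch_size = 16,
        image_height = 256,
        image_width = 256
    )
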
@@ -795,6 +795,9 @@ class VideoTokenizer(Module):

        # decoder

+        self.image_height = image_height
+        self.image_width = image_width

        # parameterize the decoder positional embeddings for MAE style training so it can be resolution agnostic

        self.to_decoder_pos_emb = create_mlp(

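Because the decoder positional embedding is an MLP over normalized (height, width) coordinates rather than a learned table of fixed size, the decode resolution can differ from the training resolution. A rough sketch of the idea, using a plain nn.Sequential as a stand-in for the repository's create_mlp helper (whose exact signature is not shown in this diff):

    import torch
    from torch import nn

    dim = 512  # placeholder model dimension

    # hypothetical stand-in for create_mlp: maps an (h, w, 2) grid of coordinates in [-1, 1] to (h, w, dim)
    to_decoder_pos_emb = nn.Sequential(nn.Linear(2, dim), nn.SiLU(), nn.Linear(dim, dim))

    def pos_emb_for(height_patches, width_patches):
        coords = torch.stack(torch.meshgrid(
            torch.linspace(-1., 1., height_patches),
            torch.linspace(-1., 1., width_patches),
            indexing = 'ij'
        ), dim = -1)                       # (h, w, 2)
        return to_decoder_pos_emb(coords)  # (h, w, dim) - works for any h, w

    print(pos_emb_for(16, 16).shape, pos_emb_for(8, 12).shape)
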
@@ -825,6 +828,10 @@ class VideoTokenizer(Module):
        if self.has_lpips_loss:
            self.lpips = LPIPSLoss(lpips_loss_network)

+    @property
+    def device(self):
+        return self.zero.device

    @torch.no_grad()
    def tokenize(
        self,

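The device property reads off a zero buffer. The usual pattern behind this (an assumption here, since the buffer's registration sits outside this diff) is to register a scalar buffer so the module's device can be recovered without inspecting parameters:

    import torch
    from torch import nn

    class ModuleWithDevice(nn.Module):
        def __init__(self):
            super().__init__()
            # dummy scalar buffer that follows the module across .to(device) / .cuda()
            self.register_buffer('zero', torch.tensor(0.), persistent = False)

        @property
        def device(self):
            return self.zero.device
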
@@ -833,6 +840,102 @@ class VideoTokenizer(Module):
        self.eval()
        return self.forward(video, return_latents = True)

+    def get_rotary_pos_emb(
+        self,
+        time,
+        num_patch_height,
+        num_patch_width
+    ):
+        device = self.device

+        positions = stack(torch.meshgrid(
+            arange(time, device = device),
+            arange(num_patch_height, device = device),
+            arange(num_patch_width, device = device)
+        ), dim = -1)

+        positions = rearrange(positions, 't h w p -> t (h w) p')

+        # give the latents an out of bounds position and assume the network will figure it out

+        positions = pad_at_dim(positions, (0, self.num_latent_tokens), dim = -2, value = -1) # todo - make this value configurable, and ultimately craft own flash attention function where certain positions can be unrotated

+        positions = rearrange(positions, 't hw p -> (t hw) p')

+        return self.spacetime_rotary(positions)

+    def decode(
+        self,
+        latents, # (b t n d)
+        height = None,
+        width = None,
+        rotary_pos_emb = None
+    ): # (b c t h w)

+        height = default(height, self.image_height)
+        width = default(width, self.image_width)

+        assert exists(height) and exists(width), f'image height and width need to be passed in when decoding latents'

+        batch, time, device = *latents.shape[:2], latents.device

+        use_flex = latents.is_cuda and exists(flex_attention)

+        num_patch_height = height // self.patch_size
+        num_patch_width = width // self.patch_size

+        if not exists(rotary_pos_emb):
+            rotary_pos_emb = self.get_rotary_pos_emb(time, num_patch_height, num_patch_width)

+        # latents to tokens

+        latent_tokens = self.latents_to_decoder(latents)

+        # generate decoder positional embedding and concat the latent token

+        spatial_pos_height = torch.linspace(-1., 1., num_patch_height, device = device)
+        spatial_pos_width = torch.linspace(-1., 1., num_patch_width, device = device)

+        space_height_width_coor = stack(torch.meshgrid(spatial_pos_height, spatial_pos_width, indexing = 'ij'), dim = -1)

+        decoder_pos_emb = self.to_decoder_pos_emb(space_height_width_coor)
+        decoder_pos_emb = repeat(decoder_pos_emb, '... -> b t ...', b = batch, t = time)

+        tokens, packed_latent_shape = pack((decoder_pos_emb, latent_tokens), 'b t * d')

+        # pack time

+        tokens, inverse_pack_time = pack_one(tokens, 'b * d')

+        seq_len = tokens.shape[-2]

+        # decoder attend

+        decoder_attend_fn = get_attend_fn(use_flex, seq_len, seq_len, special_attend_only_itself = True)

+        # decoder attention

+        for attn, ff in self.decoder_layers:
+            tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = decoder_attend_fn) + tokens

+            tokens = ff(tokens) + tokens

+        tokens = self.decoder_norm(tokens)

+        # unpack time

+        tokens = inverse_pack_time(tokens)

+        # unpack latents

+        tokens, latent_tokens = unpack(tokens, packed_latent_shape, 'b t * d')

+        # project back to patches

+        recon_video = self.tokens_to_patch(tokens)

+        return recon_video

    def forward(
        self,
        video, # (b c t h w)

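With decode factored out of forward, the encoder and decoder can now be driven separately, and the output resolution is chosen at decode time (falling back to image_height / image_width when omitted). A hedged round-trip sketch; the batch size, channel and frame counts are placeholders, while the 256 x 256 resolution and the return_latents call mirror the test at the bottom of this commit:

    import torch

    video = torch.randn(1, 3, 4, 256, 256)             # (b c t h w) - placeholder sizes

    latents = tokenizer(video, return_latents = True)  # (b t n d)
    recon = tokenizer.decode(latents, 256, 256)         # (b c t h w), decoder only

    assert recon.shape == video.shape
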
@@ -855,21 +958,7 @@ class VideoTokenizer(Module):

        # rotary positions

-        positions = stack(torch.meshgrid(
-            arange(time, device = device),
-            arange(num_patch_height, device = device),
-            arange(num_patch_width, device = device)
-        ), dim = -1)

-        positions = rearrange(positions, 't h w p -> t (h w) p')

-        # give the latents an out of bounds position and assume the network will figure it out

-        positions = pad_at_dim(positions, (0, self.num_latent_tokens), dim = -2, value = -1) # todo - make this value configurable, and ultimately craft own flash attention function where certain positions can be unrotated

-        positions = rearrange(positions, 't hw p -> (t hw) p')

-        rotary_pos_emb = self.spacetime_rotary(positions)
+        rotary_pos_emb = self.get_rotary_pos_emb(time, num_patch_height, num_patch_width)

        # masking

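The forward pass now reuses get_rotary_pos_emb instead of rebuilding the (time, height, width) grid inline. A small shape walk-through of what that helper computes, with toy sizes (2 frames, a 2x2 patch grid, 4 latent tokens); F.pad stands in for the repository's pad_at_dim helper, and the explicit indexing argument is only there to silence the meshgrid warning:

    import torch
    import torch.nn.functional as F
    from torch import arange, stack
    from einops import rearrange

    t, h, w, num_latent_tokens = 2, 2, 2, 4

    positions = stack(torch.meshgrid(arange(t), arange(h), arange(w), indexing = 'ij'), dim = -1)  # (2, 2, 2, 3)
    positions = rearrange(positions, 't h w p -> t (h w) p')                                       # (2, 4, 3)

    # latent tokens get an out of bounds position of -1, analogous to pad_at_dim(..., value = -1)
    positions = F.pad(positions, (0, 0, 0, num_latent_tokens), value = -1)                         # (2, 8, 3)
    positions = rearrange(positions, 't hw p -> (t hw) p')                                         # (16, 3)

    print(positions.shape)  # torch.Size([16, 3])
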
@@ -941,48 +1030,7 @@ class VideoTokenizer(Module):
        if return_latents:
            return latents

-        latent_tokens = self.latents_to_decoder(latents)

-        # generate decoder positional embedding and concat the latent token

-        spatial_pos_height = torch.linspace(-1., 1., num_patch_height, device = device)
-        spatial_pos_width = torch.linspace(-1., 1., num_patch_width, device = device)

-        space_height_width_coor = stack(torch.meshgrid(spatial_pos_height, spatial_pos_width, indexing = 'ij'), dim = -1)

-        decoder_pos_emb = self.to_decoder_pos_emb(space_height_width_coor)
-        decoder_pos_emb = repeat(decoder_pos_emb, '... -> b t ...', b = batch, t = time)

-        tokens, packed_latent_shape = pack((decoder_pos_emb, latent_tokens), 'b t * d')

-        # pack time

-        tokens, inverse_pack_time = pack_one(tokens, 'b * d')

-        # decoder attend

-        decoder_attend_fn = get_attend_fn(use_flex, seq_len, seq_len, special_attend_only_itself = True)

-        # decoder attention

-        for attn, ff in self.decoder_layers:
-            tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = decoder_attend_fn) + tokens

-            tokens = ff(tokens) + tokens

-        tokens = self.decoder_norm(tokens)

-        # unpack time

-        tokens = inverse_pack_time(tokens)

-        # unpack latents

-        tokens, latent_tokens = unpack(tokens, packed_latent_shape, 'b t * d')

-        # project back to patches

-        recon_video = self.tokens_to_patch(tokens)
+        recon_video = self.decode(latents, height = height, width = width, rotary_pos_emb = rotary_pos_emb)

        # losses

@@ -1180,6 +1228,18 @@ class DynamicsModel(Module):

        return list(set(params) - set(self.video_tokenizer.parameters()))

+    def generate(
+        self,
+        num_frames,
+        num_steps = 4,
+        image_height = None,
+        image_width = None
+    ): # (b t n d) | (b c t h w)

+        assert log(num_steps).is_integer(), f'number of steps must be a power of 2'

+        raise NotImplementedError

    def forward(
        self,
        *,

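generate is still a stub, but the reorganization points at how it would eventually produce video: sample latents with the dynamics model, then hand them to the tokenizer's standalone decode. A purely illustrative sketch, not the repository's implementation; it assumes generate will end up returning latents shaped (b t n d) rather than pixels, and that the power-of-two check on num_steps relies on a base-2 log helper (with the standard library that would be math.log2):

    # hypothetical once generate is implemented - today it raises NotImplementedError
    latents = dynamics.generate(num_frames = 16, num_steps = 4, image_height = 256, image_width = 256)  # (b t n d)

    video = dynamics.video_tokenizer.decode(latents, height = 256, width = 256)  # (b c t h w)
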
@@ -40,6 +40,9 @@ def test_e2e(
    latents = tokenizer(video, return_latents = True)
    assert latents.shape[-1] == 16

+    recon = tokenizer.decode(latents, 256, 256)
+    assert recon.shape == video.shape

    query_heads, heads = (16, 4) if grouped_query_attn else (8, 8)

    dynamics = DynamicsModel(