reorganize tokenizer to generate video from the dynamics model

parent 7180a8cf43
commit 83ba9a285a
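Summary of the change, as reflected in the hunks below: the rotary position construction and the decoder pass move out of VideoTokenizer.forward into two new methods, get_rotary_pos_emb and decode, so that latents produced elsewhere (for example by the dynamics model) can be decoded back into video. image_height / image_width become stored defaults for decoding, a device property is added, and DynamicsModel gains a generate stub. A minimal round-trip sketch of the new surface, using a hypothetical configuration (only keyword arguments visible in this diff are used; any other required constructor arguments, and the import path, are not shown here):

import torch
# from <package> import VideoTokenizer   (import path not shown in this diff)

# hypothetical configuration - values chosen only for illustration,
# matching the 256x256 / dim_latent 16 setup exercised in the test hunk below
tokenizer = VideoTokenizer(
    dim = 512,
    dim_latent = 16,
    patch_size = 16,
    image_height = 256,   # new: stored so decode() can fall back to it
    image_width = 256,
    num_latent_tokens = 4,
    encoder_depth = 4,
    decoder_depth = 4
)

video = torch.randn(1, 3, 4, 256, 256)          # (b c t h w)

latents = tokenizer.tokenize(video)             # (b t n d)
recon = tokenizer.decode(latents, 256, 256)     # back to (b c t h w)

assert recon.shape == video.shape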
@@ -20,8 +20,6 @@ from x_mlps_pytorch.ensemble import Ensemble

 from assoc_scan import AssocScan

-from accelerate import Accelerator
-
 # ein related

 # b - batch
@@ -710,6 +708,8 @@ class VideoTokenizer(Module):
         dim,
         dim_latent,
         patch_size,
+        image_height = None,
+        image_width = None,
         num_latent_tokens = 4,
         encoder_depth = 4,
         decoder_depth = 4,
@@ -795,6 +795,9 @@ class VideoTokenizer(Module):

         # decoder

+        self.image_height = image_height
+        self.image_width = image_width
+
         # parameterize the decoder positional embeddings for MAE style training so it can be resolution agnostic

         self.to_decoder_pos_emb = create_mlp(
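The stored image_height / image_width only provide defaults; what makes the decoder resolution agnostic is that its positional embedding comes from an MLP over normalized patch coordinates rather than a learned table of fixed size. A small sketch of that shape contract, with a stand-in nn.Sequential in place of create_mlp (whose exact signature is not shown in this diff) and illustrative sizes:

import torch
from torch import nn

dim = 512                            # illustrative model dimension
to_decoder_pos_emb = nn.Sequential(  # stand-in for create_mlp(...)
    nn.Linear(2, dim), nn.SiLU(), nn.Linear(dim, dim)
)

h, w = 16, 16                        # patch grid, e.g. a 256x256 image with patch_size 16
coords = torch.stack(torch.meshgrid(
    torch.linspace(-1., 1., h),
    torch.linspace(-1., 1., w),
    indexing = 'ij'
), dim = -1)                         # (h, w, 2) normalized coordinates, as built in decode() below

pos_emb = to_decoder_pos_emb(coords) # (h, w, dim) - works for any h, w at decode time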
@@ -825,6 +828,10 @@ class VideoTokenizer(Module):
         if self.has_lpips_loss:
             self.lpips = LPIPSLoss(lpips_loss_network)

+    @property
+    def device(self):
+        return self.zero.device
+
     @torch.no_grad()
     def tokenize(
         self,
@@ -833,6 +840,102 @@ class VideoTokenizer(Module):
         self.eval()
         return self.forward(video, return_latents = True)

+    def get_rotary_pos_emb(
+        self,
+        time,
+        num_patch_height,
+        num_patch_width
+    ):
+        device = self.device
+
+        positions = stack(torch.meshgrid(
+            arange(time, device = device),
+            arange(num_patch_height, device = device),
+            arange(num_patch_width, device = device)
+        ), dim = -1)
+
+        positions = rearrange(positions, 't h w p -> t (h w) p')
+
+        # give the latents an out of bounds position and assume the network will figure it out
+
+        positions = pad_at_dim(positions, (0, self.num_latent_tokens), dim = -2, value = -1) # todo - make this value configurable, and ultimately craft own flash attention function where certain positions can be unrotated
+
+        positions = rearrange(positions, 't hw p -> (t hw) p')
+
+        return self.spacetime_rotary(positions)
+
+    def decode(
+        self,
+        latents, # (b t n d)
+        height = None,
+        width = None,
+        rotary_pos_emb = None
+    ): # (b c t h w)
+
+        height = default(height, self.image_height)
+        width = default(width, self.image_width)
+
+        assert exists(height) and exists(width), f'image height and width need to be passed in when decoding latents'
+
+        batch, time, device = *latents.shape[:2], latents.device
+
+        use_flex = latents.is_cuda and exists(flex_attention)
+
+        num_patch_height = height // self.patch_size
+        num_patch_width = width // self.patch_size
+
+        if not exists(rotary_pos_emb):
+            rotary_pos_emb = self.get_rotary_pos_emb(time, num_patch_height, num_patch_width)
+
+        # latents to tokens
+
+        latent_tokens = self.latents_to_decoder(latents)
+
+        # generate decoder positional embedding and concat the latent token
+
+        spatial_pos_height = torch.linspace(-1., 1., num_patch_height, device = device)
+        spatial_pos_width = torch.linspace(-1., 1., num_patch_width, device = device)
+
+        space_height_width_coor = stack(torch.meshgrid(spatial_pos_height, spatial_pos_width, indexing = 'ij'), dim = -1)
+
+        decoder_pos_emb = self.to_decoder_pos_emb(space_height_width_coor)
+        decoder_pos_emb = repeat(decoder_pos_emb, '... -> b t ...', b = batch, t = time)
+
+        tokens, packed_latent_shape = pack((decoder_pos_emb, latent_tokens), 'b t * d')
+
+        # pack time
+
+        tokens, inverse_pack_time = pack_one(tokens, 'b * d')
+
+        seq_len = tokens.shape[-2]
+
+        # decoder attend
+
+        decoder_attend_fn = get_attend_fn(use_flex, seq_len, seq_len, special_attend_only_itself = True)
+
+        # decoder attention
+
+        for attn, ff in self.decoder_layers:
+            tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = decoder_attend_fn) + tokens
+
+            tokens = ff(tokens) + tokens
+
+        tokens = self.decoder_norm(tokens)
+
+        # unpack time
+
+        tokens = inverse_pack_time(tokens)
+
+        # unpack latents
+
+        tokens, latent_tokens = unpack(tokens, packed_latent_shape, 'b t * d')
+
+        # project back to patches
+
+        recon_video = self.tokens_to_patch(tokens)
+
+        return recon_video
+
     def forward(
         self,
         video, # (b c t h w)
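A shape walkthrough of get_rotary_pos_emb added above, using plain torch in place of the repo's stack / arange / pad_at_dim helpers (sizes are illustrative): each patch token gets an integer (t, h, w) coordinate, the latent tokens are appended per frame with an out-of-bounds coordinate of -1, and the result is flattened before being handed to spacetime_rotary.

import torch

time, h, w, num_latent_tokens = 2, 4, 4, 4

positions = torch.stack(torch.meshgrid(
    torch.arange(time), torch.arange(h), torch.arange(w),
    indexing = 'ij'
), dim = -1)                                        # (t, h, w, 3)

positions = positions.reshape(time, h * w, 3)       # 't h w p -> t (h w) p'

# equivalent of pad_at_dim(..., (0, num_latent_tokens), dim = -2, value = -1):
# the latent tokens ride along with an out-of-bounds position per frame
latent_positions = torch.full((time, num_latent_tokens, 3), -1)
positions = torch.cat((positions, latent_positions), dim = -2)

positions = positions.reshape(-1, 3)                # flatten time and tokens
print(positions.shape)                              # torch.Size([40, 3]) = (2 * (16 + 4), 3)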
@@ -855,21 +958,7 @@ class VideoTokenizer(Module):

         # rotary positions

-        positions = stack(torch.meshgrid(
-            arange(time, device = device),
-            arange(num_patch_height, device = device),
-            arange(num_patch_width, device = device)
-        ), dim = -1)
-
-        positions = rearrange(positions, 't h w p -> t (h w) p')
-
-        # give the latents an out of bounds position and assume the network will figure it out
-
-        positions = pad_at_dim(positions, (0, self.num_latent_tokens), dim = -2, value = -1) # todo - make this value configurable, and ultimately craft own flash attention function where certain positions can be unrotated
-
-        positions = rearrange(positions, 't hw p -> (t hw) p')
-
-        rotary_pos_emb = self.spacetime_rotary(positions)
+        rotary_pos_emb = self.get_rotary_pos_emb(time, num_patch_height, num_patch_width)

         # masking

@@ -941,48 +1030,7 @@ class VideoTokenizer(Module):
         if return_latents:
             return latents

-        latent_tokens = self.latents_to_decoder(latents)
-
-        # generate decoder positional embedding and concat the latent token
-
-        spatial_pos_height = torch.linspace(-1., 1., num_patch_height, device = device)
-        spatial_pos_width = torch.linspace(-1., 1., num_patch_width, device = device)
-
-        space_height_width_coor = stack(torch.meshgrid(spatial_pos_height, spatial_pos_width, indexing = 'ij'), dim = -1)
-
-        decoder_pos_emb = self.to_decoder_pos_emb(space_height_width_coor)
-        decoder_pos_emb = repeat(decoder_pos_emb, '... -> b t ...', b = batch, t = time)
-
-        tokens, packed_latent_shape = pack((decoder_pos_emb, latent_tokens), 'b t * d')
-
-        # pack time
-
-        tokens, inverse_pack_time = pack_one(tokens, 'b * d')
-
-        # decoder attend
-
-        decoder_attend_fn = get_attend_fn(use_flex, seq_len, seq_len, special_attend_only_itself = True)
-
-        # decoder attention
-
-        for attn, ff in self.decoder_layers:
-            tokens = attn(tokens, rotary_pos_emb = rotary_pos_emb, attend_fn = decoder_attend_fn) + tokens
-
-            tokens = ff(tokens) + tokens
-
-        tokens = self.decoder_norm(tokens)
-
-        # unpack time
-
-        tokens = inverse_pack_time(tokens)
-
-        # unpack latents
-
-        tokens, latent_tokens = unpack(tokens, packed_latent_shape, 'b t * d')
-
-        # project back to patches
-
-        recon_video = self.tokens_to_patch(tokens)
+        recon_video = self.decode(latents, height = height, width = width, rotary_pos_emb = rotary_pos_emb)

         # losses

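The block removed here is the code that now lives in VideoTokenizer.decode above; its central mechanic is einops pack / unpack, which concatenates the per-patch decoder positional embeddings with the latent tokens along one token axis and later splits them apart again using the recorded shapes. A self-contained sketch of that contract with illustrative sizes:

import torch
from einops import pack, unpack

b, t, h, w, n, d = 1, 2, 4, 4, 4, 8

decoder_pos_emb = torch.randn(b, t, h, w, d)   # one embedding per output patch position
latent_tokens   = torch.randn(b, t, n, d)      # latent tokens conditioning the decoder

# '*' absorbs (h, w) for the first tensor and (n,) for the second,
# giving (b, t, h*w + n, d) plus a record of the original shapes
tokens, packed_shape = pack((decoder_pos_emb, latent_tokens), 'b t * d')
print(tokens.shape)                            # torch.Size([1, 2, 20, 8])

# after attention, the same record splits patch tokens and latents back apart
patch_tokens, latent_tokens = unpack(tokens, packed_shape, 'b t * d')
print(patch_tokens.shape, latent_tokens.shape) # (1, 2, 4, 4, 8) and (1, 2, 4, 8)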
@@ -1180,6 +1228,18 @@ class DynamicsModel(Module):

         return list(set(params) - set(self.video_tokenizer.parameters()))

+    def generate(
+        self,
+        num_frames,
+        num_steps = 4,
+        image_height = None,
+        image_width = None
+    ): # (b t n d) | (b c t h w)
+
+        assert log(num_steps).is_integer(), f'number of steps must be a power of 2'
+
+        raise NotImplementedError
+
     def forward(
         self,
         *,
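One observation on the new stub: log(num_steps).is_integer() only expresses "power of 2" if log resolves to a base-2 logarithm in this module; with the natural log, log(4) is about 1.386 and the assert would reject a valid num_steps. A quick illustration of the difference, plus an integer-only alternative:

from math import log, log2

print(log(4).is_integer())    # False - natural log of 4 is ~1.386
print(log2(4).is_integer())   # True  - 4 is 2 ** 2

# an equivalent check that avoids floating point entirely
def is_power_of_two(n: int) -> bool:
    return n > 0 and (n & (n - 1)) == 0

assert is_power_of_two(4) and not is_power_of_two(6)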
@@ -40,6 +40,9 @@ def test_e2e(
     latents = tokenizer(video, return_latents = True)
     assert latents.shape[-1] == 16

+    recon = tokenizer.decode(latents, 256, 256)
+    assert recon.shape == video.shape
+
     query_heads, heads = (16, 4) if grouped_query_attn else (8, 8)

     dynamics = DynamicsModel(