an anonymous researcher pointed out that the video tokenizer may be using multiple latents per frame
commit bfbecb4968 (parent 338def693d)
@@ -98,6 +98,15 @@ def pad_at_dim(
     zeros = ((0, 0) * dims_from_right)
     return F.pad(t, (*zeros, *pad), value = value)
 
+def align_dims_left(t, aligned_to):
+    shape = t.shape
+    num_right_dims = aligned_to.ndim - t.ndim
+
+    if num_right_dims < 0:
+        return
+
+    return t.reshape(*shape, *((1,) * num_right_dims))
+
 def l2norm(t):
     return F.normalize(t, dim = -1, p = 2)
 
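For intuition, a minimal sketch of what the new `align_dims_left` helper enables (toy shapes; the helper itself is copied from the hunk above): a per-frame scalar tensor is right-padded with singleton dims so it broadcasts against latents of any rank.

```python
import torch

def align_dims_left(t, aligned_to):
    shape = t.shape
    num_right_dims = aligned_to.ndim - t.ndim

    if num_right_dims < 0:
        return

    return t.reshape(*shape, *((1,) * num_right_dims))

latents = torch.randn(2, 8, 4, 32)       # (batch, time, num latents, dim) - toy shapes
times = torch.rand(2, 8)                 # one scalar per frame

times = align_dims_left(times, latents)  # (2, 8, 1, 1)
noised = times * latents                 # broadcasts over latents and dim
```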
@@ -677,6 +686,7 @@ class VideoTokenizer(Module):
         dim,
         dim_latent,
         patch_size,
+        num_latent_tokens = 4,
         encoder_depth = 4,
         decoder_depth = 4,
         attn_kwargs: dict = dict(),
@@ -701,7 +711,9 @@ class VideoTokenizer(Module):
 
         # special tokens
 
-        self.latent_token = Parameter(torch.randn(dim) * 1e-2)
+        assert num_latent_tokens >= 1
+        self.num_latent_tokens = num_latent_tokens
+        self.latent_tokens = Parameter(torch.randn(num_latent_tokens, dim) * 1e-2)
 
         # mae masking - Kaiming He paper from long ago
 
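The single learned `latent_token` becomes a bank of `num_latent_tokens` learned queries. A shape-only sketch (toy dims) of how they get broadcast across batch and time further down:

```python
import torch
from torch.nn import Parameter
from einops import repeat

dim, num_latent_tokens = 32, 4  # toy dims
latent_tokens = Parameter(torch.randn(num_latent_tokens, dim) * 1e-2)

batch, time = 2, 8
latents = repeat(latent_tokens, 'n d -> b t n d', b = batch, t = time)
assert latents.shape == (2, 8, 4, 32)  # one bank of latents per frame
```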
@@ -829,7 +841,7 @@ class VideoTokenizer(Module):
 
         # give the latents an out of bounds position and assume the network will figure it out
 
-        positions = pad_at_dim(positions, (0, 1), dim = -2, value = -1) # todo - make this value configurable, and ultimately craft own flash attention function where certain positions can be unrotated
+        positions = pad_at_dim(positions, (0, self.num_latent_tokens), dim = -2, value = -1) # todo - make this value configurable, and ultimately craft own flash attention function where certain positions can be unrotated
 
         positions = rearrange(positions, 't hw p -> (t hw) p')
 
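The positions are now extended by `num_latent_tokens` out-of-bounds entries instead of just one. A sketch, assuming `pad_at_dim` computes `dims_from_right` in the usual way (that first line of its body is not shown in the hunk at the top of this diff):

```python
import torch
import torch.nn.functional as F

def pad_at_dim(t, pad, dim = -1, value = 0.):
    # assumed reconstruction - only the last two lines appear in the diff above
    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = ((0, 0) * dims_from_right)
    return F.pad(t, (*zeros, *pad), value = value)

num_latent_tokens = 4
positions = torch.zeros(3, 16, 2)  # toy (t, hw, pos dim)

positions = pad_at_dim(positions, (0, num_latent_tokens), dim = -2, value = -1)
assert positions.shape == (3, 20, 2)  # latents appended, all at position -1
```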
@@ -855,7 +867,7 @@ class VideoTokenizer(Module):
 
         # add the latent
 
-        latents = repeat(self.latent_token, 'd -> b t d', b = tokens.shape[0], t = tokens.shape[1])
+        latents = repeat(self.latent_tokens, 'n d -> b t n d', b = tokens.shape[0], t = tokens.shape[1])
 
         tokens, packed_latent_shape = pack((tokens, latents), 'b t * d')
 
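With a latent axis per frame, the latents are packed next to the patch tokens along the `'b t * d'` pattern. A toy demonstration of the pack and its saved shapes:

```python
import torch
from einops import pack, unpack

tokens = torch.randn(2, 8, 16, 32)   # (b, t, hw patch tokens, d) - toy shapes
latents = torch.randn(2, 8, 4, 32)   # (b, t, num latent tokens, d)

tokens, packed_latent_shape = pack((tokens, latents), 'b t * d')
assert tokens.shape == (2, 8, 20, 32)

# packed_latent_shape remembers the 16 / 4 split for unpacking later
tokens, latents = unpack(tokens, packed_latent_shape, 'b t * d')
assert latents.shape == (2, 8, 4, 32)
```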
@@ -917,7 +929,11 @@ class VideoTokenizer(Module):
         decoder_pos_emb = self.to_decoder_pos_emb(space_height_width_coor)
         decoder_pos_emb = repeat(decoder_pos_emb, '... -> b t ...', b = batch, t = time)
 
-        tokens, _ = pack((decoder_pos_emb, latent_tokens), 'b * d')
+        tokens, packed_latent_shape = pack((decoder_pos_emb, latent_tokens), 'b t * d')
+
+        # pack time
+
+        tokens, inverse_pack_time = pack_one(tokens, 'b * d')
 
         # decoder attend
 
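The decoder now packs the positional embeddings with the latents per frame, then flattens time for attention. A sketch, assuming `pack_one` is the usual single-tensor convenience wrapper around `einops.pack` that also returns an inverse:

```python
import torch
from einops import pack, unpack

def pack_one(t, pattern):
    # assumed helper - a common convenience around einops.pack for one tensor
    packed, ps = pack([t], pattern)

    def inverse(out):
        return unpack(out, ps, pattern)[0]

    return packed, inverse

tokens = torch.randn(2, 8, 20, 32)  # (b, t, patches + latents, d) - toy shapes

tokens, inverse_pack_time = pack_one(tokens, 'b * d')
assert tokens.shape == (2, 160, 32)  # time and space flattened for attention

tokens = inverse_pack_time(tokens)   # back to (2, 8, 20, 32)
```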
@@ -936,13 +952,9 @@ class VideoTokenizer(Module):
 
         tokens = inverse_pack_time(tokens)
 
-        # excise latents
+        # unpack latents
 
-        tokens = tokens[..., :-1, :]
-
-        # unpack space
-
-        tokens = inverse_pack_space(tokens)
+        tokens, latent_tokens = unpack(tokens, packed_latent_shape, 'b t * d')
 
         # project back to patches
 
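The old excision `tokens[..., :-1, :]` dropped exactly one latent and would silently leave strays behind with more; unpacking against the saved shape removes however many were packed. A toy check:

```python
import torch
from einops import pack, unpack

patches = torch.randn(2, 8, 16, 32)  # toy shapes
latents = torch.randn(2, 8, 4, 32)

tokens, packed_latent_shape = pack((patches, latents), 'b t * d')

assert tokens[..., :-1, :].shape[-2] == 19  # old way: 3 latents left behind

tokens, latent_tokens = unpack(tokens, packed_latent_shape, 'b t * d')
assert tokens.shape[-2] == 16 and latent_tokens.shape[-2] == 4  # clean split
```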
@@ -979,9 +991,9 @@ class DynamicsModel(Module):
         dim,
         dim_latent,
         video_tokenizer: VideoTokenizer | None = None,
-        max_steps = 64, # K_max in paper
-        num_spatial_tokens = 32, # latents were projected into spatial tokens, and presumably pooled back for the final prediction (or one special one does the x-prediction)
-        num_register_tokens = 8, # they claim register tokens led to better temporal consistency
+        max_steps = 64, # K_max in paper
+        num_register_tokens = 8, # they claim register tokens led to better temporal consistency
+        num_spatial_tokens_per_latent = 2, # latents can be projected to greater number of tokens
         num_tasks = 0,
         depth = 4,
         pred_orig_latent = True, # directly predicting the original x0 data yield better results, rather than velocity (x-space vs v-space)
@@ -1005,10 +1017,11 @@ class DynamicsModel(Module):
         # spatial and register tokens
 
         self.latents_to_spatial_tokens = Sequential(
-            Linear(dim_latent, dim * num_spatial_tokens),
-            Rearrange('... (tokens d) -> ... tokens d', tokens = num_spatial_tokens)
+            Linear(dim_latent, dim * num_spatial_tokens_per_latent),
+            Rearrange('... (tokens d) -> ... tokens d', tokens = num_spatial_tokens_per_latent)
         )
 
+        self.num_register_tokens = num_register_tokens
         self.register_tokens = Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
 
         # signal and step sizes
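The projection is renamed to reflect that it now expands each latent into `num_spatial_tokens_per_latent` tokens, so the total spatial token count scales with the number of latents per frame. A shape-only sketch with toy dims:

```python
import torch
from torch.nn import Linear, Sequential
from einops.layers.torch import Rearrange

dim_latent, dim, num_spatial_tokens_per_latent = 16, 32, 2  # toy dims

latents_to_spatial_tokens = Sequential(
    Linear(dim_latent, dim * num_spatial_tokens_per_latent),
    Rearrange('... (tokens d) -> ... tokens d', tokens = num_spatial_tokens_per_latent)
)

latents = torch.randn(2, 8, 4, dim_latent)        # (b, t, num latents, dim latent)
space_tokens = latents_to_spatial_tokens(latents)
assert space_tokens.shape == (2, 8, 4, 2, 32)     # 4 latents x 2 tokens each
```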
@@ -1039,15 +1052,6 @@ class DynamicsModel(Module):
         self.num_tasks = num_tasks
         self.task_embed = nn.Embedding(num_tasks, dim)
 
-        # calculate "space" seq len
-
-        self.space_seq_len = (
-            1 # action / agent token
-            + 1 # signal + step
-            + num_register_tokens
-            + num_spatial_tokens
-        )
-
         # attention
 
         self.attn_softclamp_value = attn_softclamp_value
@@ -1099,7 +1103,7 @@ class DynamicsModel(Module):
         self,
         *,
         video = None,
-        latents = None, # (b t d)
+        latents = None, # (b t n d) | (b t d)
         signal_levels = None, # (b t)
         step_sizes_log2 = None, # (b)
         tasks = None, # (b)
@@ -1114,6 +1118,11 @@ class DynamicsModel(Module):
 
             latents = self.video_tokenizer.tokenize(video)
 
+        if latents.ndim == 3:
+            latents = rearrange(latents, 'b t d -> b t 1 d') # 1 latent edge case
+
         # variables
 
         batch, time, device = *latents.shape[:2], latents.device
 
         # flow related
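Single-latent inputs are normalized up front so the rest of the forward pass can assume a latent axis. A minimal illustration:

```python
import torch
from einops import rearrange

latents = torch.randn(2, 8, 32)  # (b t d) - one latent per frame, toy shapes

if latents.ndim == 3:
    latents = rearrange(latents, 'b t d -> b t 1 d')

assert latents.shape == (2, 8, 1, 32)
```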
@@ -1146,7 +1155,11 @@ class DynamicsModel(Module):
 
         # times is from 0 to 1
 
-        times = rearrange(signal_levels.float() / self.max_steps, 'b t -> b t 1')
+        def get_times_from_signal_level(signal_levels):
+            times = signal_levels.float() / self.max_steps
+            return align_dims_left(times, latents)
+
+        times = get_times_from_signal_level(signal_levels)
 
         # noise from 0 as noise to 1 as data
 
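`get_times_from_signal_level` replaces the hardcoded `'b t -> b t 1'` rearrange, aligning times to whatever rank the latents have. A toy sketch of the noising it feeds into, following the "0 as noise, 1 as data" convention in the comments:

```python
import torch

max_steps = 64
latents = torch.randn(2, 8, 4, 32)                   # (b, t, n, d) - toy shapes
signal_levels = torch.randint(0, max_steps, (2, 8))  # (b, t)

times = signal_levels.float() / max_steps            # in [0, 1)
times = times.reshape(*times.shape, 1, 1)            # what align_dims_left produces here

noise = torch.randn_like(latents)
noised_latents = latents * times + noise * (1. - times)  # 0 is pure noise, 1 is data
```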
@@ -1169,6 +1182,10 @@ class DynamicsModel(Module):
 
         space_tokens = self.latents_to_spatial_tokens(noised_latents)
 
+        space_tokens, inverse_pack_space_per_latent = pack_one(space_tokens, 'b t * d')
+
+        num_spatial_tokens = space_tokens.shape[-2]
+
         # pack to tokens
         # [signal + step size embed] [latent space tokens] [register] [actions / agent]
 
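Spatial tokens now carry two token axes, `(num latents, tokens per latent)`, which get packed into one flat axis for attention and counted at runtime. Shapes only:

```python
import torch
from einops import pack, unpack

space_tokens = torch.randn(2, 8, 4, 2, 32)  # (b, t, nl, s, d) - toy shapes

packed, ps = pack([space_tokens], 'b t * d')
assert packed.shape == (2, 8, 8, 32)        # 4 latents x 2 tokens each

num_spatial_tokens = packed.shape[-2]       # 8, only known at runtime now

space_tokens = unpack(packed, ps, 'b t * d')[0]  # per-latent axes restored
assert space_tokens.shape == (2, 8, 4, 2, 32)
```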
@@ -1200,7 +1217,14 @@ class DynamicsModel(Module):
 
         attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, device = device)
 
-        space_attend = get_attend_fn(causal = False, seq_len = self.space_seq_len, k_seq_len = self.space_seq_len, num_special_tokens = 1, **attend_kwargs) # space has an agent token on the right-hand side for reinforcement learning - cannot be attended to by modality
+        space_seq_len = (
+            1 # action / agent token
+            + 1 # signal + step
+            + self.num_register_tokens
+            + num_spatial_tokens
+        )
+
+        space_attend = get_attend_fn(causal = False, seq_len = space_seq_len, k_seq_len = space_seq_len, num_special_tokens = 1, **attend_kwargs) # space has an agent token on the right-hand side for reinforcement learning - cannot be attended to by modality
 
         time_attend = get_attend_fn(causal = True, seq_len = time, k_seq_len = time, **attend_kwargs)
 
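Since the spatial token count now depends on the latents actually passed in, the sequence length can no longer be cached in `__init__` (hence the deletion at -1039 above) and is assembled per forward pass. With the toy numbers used here:

```python
num_latents, tokens_per_latent, num_register_tokens = 4, 2, 8  # toy numbers
num_spatial_tokens = num_latents * tokens_per_latent

space_seq_len = (
    1                      # action / agent token
    + 1                    # signal + step embed
    + num_register_tokens
    + num_spatial_tokens
)
assert space_seq_len == 18
```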
@@ -1236,7 +1260,9 @@ class DynamicsModel(Module):
 
         # pooling
 
-        pooled = reduce(space_tokens, 'b t s d -> b t d', 'mean')
+        space_tokens = inverse_pack_space_per_latent(space_tokens)
+
+        pooled = reduce(space_tokens, 'b t nl s d -> b t nl d', 'mean')
 
         pred = self.to_pred(pooled)
 
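Pooling now averages only within each latent's group of spatial tokens, yielding one prediction per latent per frame rather than one per frame:

```python
import torch
from einops import reduce

space_tokens = torch.randn(2, 8, 4, 2, 32)  # (b, t, nl, s, d) - toy shapes

pooled = reduce(space_tokens, 'b t nl s d -> b t nl d', 'mean')
assert pooled.shape == (2, 8, 4, 32)        # one pooled vector per latent
```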
@@ -1285,21 +1311,24 @@ class DynamicsModel(Module):
         if is_v_space_pred:
             first_step_pred_flow = first_step_pred
         else:
-            first_times = signal_levels[..., None].float() / self.max_steps
+            first_times = get_times_from_signal_level(signal_levels)
             first_step_pred_flow = (first_step_pred - noised_latents) / (1. - first_times)
 
         # take a half step
 
-        denoised_latent = noised_latents + first_step_pred_flow * (half_step_size[:, None, None] / self.max_steps)
+        half_step_size_align_left = align_dims_left(half_step_size, noised_latents)
+
+        denoised_latent = noised_latents + first_step_pred_flow * (half_step_size_align_left / self.max_steps)
 
         # get second prediction for b''
 
-        second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels + half_step_size[:, None], step_sizes_log2_minus_one, agent_tokens)
+        signal_levels_plus_half_step = signal_levels + half_step_size[:, None]
+        second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels_plus_half_step, step_sizes_log2_minus_one, agent_tokens)
 
         if is_v_space_pred:
             second_step_pred_flow = second_step_pred
         else:
-            second_times = signal_levels[..., None].float() / self.max_steps
+            second_times = get_times_from_signal_level(signal_levels_plus_half_step)
             second_step_pred_flow = (second_step_pred - denoised_latent) / (1. - second_times)
 
         # pred target is sg(b' + b'') / 2
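Pulling the x-space branch of this half-step logic together, a toy end-to-end sketch (`predict_x0` is a hypothetical stand-in for `get_prediction_no_grad`; shapes are toy). Note the fix folded into this hunk: the second times now come from `signal_levels_plus_half_step` rather than `signal_levels`:

```python
import torch

max_steps = 64

def align_dims_left(t, aligned_to):
    return t.reshape(*t.shape, *((1,) * (aligned_to.ndim - t.ndim)))

def predict_x0(noised, signal_levels):
    # hypothetical stand-in for the model's x-space prediction
    return noised * 0.9

noised_latents = torch.randn(2, 8, 4, 32)
signal_levels = torch.randint(0, max_steps // 2, (2, 8))
half_step_size = torch.full((2,), 4)

# first prediction b' converted to a flow
first_pred = predict_x0(noised_latents, signal_levels)
first_times = align_dims_left(signal_levels.float() / max_steps, noised_latents)
first_flow = (first_pred - noised_latents) / (1. - first_times)

# take a half step along that flow
half = align_dims_left(half_step_size, noised_latents)
denoised_latent = noised_latents + first_flow * (half / max_steps)

# second prediction b'' at the advanced signal level (the fix above)
signal_levels_plus_half_step = signal_levels + half_step_size[:, None]
second_pred = predict_x0(denoised_latent, signal_levels_plus_half_step)
second_times = align_dims_left(signal_levels_plus_half_step.float() / max_steps, denoised_latent)
second_flow = (second_pred - denoised_latent) / (1. - second_times)

# pred target is sg(b' + b'') / 2
target = ((first_flow + second_flow) / 2).detach()
```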