handle spatial tokens less than latent tokens in dynamics model

commit 25b8de91cc
parent bfbecb4968
@@ -37,7 +37,7 @@ from accelerate import Accelerator

 import einx
 from einops import einsum, rearrange, repeat, reduce, pack, unpack
-from einops.layers.torch import Rearrange
+from einops.layers.torch import Rearrange, Reduce

 # flex attention - but will make sure it works if it is not available
 # may also end up crafting own custom flash attention kernel for this work
@@ -993,7 +993,8 @@ class DynamicsModel(Module):
         video_tokenizer: VideoTokenizer | None = None,
         max_steps = 64,             # K_max in paper
         num_register_tokens = 8,    # they claim register tokens led to better temporal consistency
-        num_spatial_tokens_per_latent = 2, # latents can be projected to greater number of tokens
+        num_spatial_tokens = 2,     # latents projected to greater number of spatial tokens
+        num_latent_tokens = None,
         num_tasks = 0,
         depth = 4,
         pred_orig_latent = True,    # directly predicting the original x0 data yields better results, rather than velocity (x-space vs v-space)
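For readability of the next hunk: `exists`, `default`, and `divisible_by` are the usual small helpers in this codebase. Their assumed definitions (a sketch, not part of this commit) would be:

def exists(v):
    return v is not None

def default(v, fallback):
    return v if exists(v) else fallback

def divisible_by(num, den):
    return (num % den) == 0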
@@ -1014,13 +1015,50 @@ class DynamicsModel(Module):

         self.video_tokenizer = video_tokenizer

-        # spatial and register tokens
+        if exists(video_tokenizer):
+            num_latent_tokens = default(num_latent_tokens, video_tokenizer.num_latent_tokens)
+            assert video_tokenizer.num_latent_tokens == num_latent_tokens, '`num_latent_tokens` must be the same for the tokenizer and dynamics model'
+
+        assert exists(num_latent_tokens), '`num_latent_tokens` must be set'
+
+        # spatial
+
+        self.num_latent_tokens = num_latent_tokens
+
+        if num_spatial_tokens >= num_latent_tokens:
+            assert divisible_by(num_spatial_tokens, num_latent_tokens)
+
+            expand_factor = num_spatial_tokens // num_latent_tokens

-        self.latents_to_spatial_tokens = Sequential(
-            Linear(dim_latent, dim * num_spatial_tokens_per_latent),
-            Rearrange('... (tokens d) -> ... tokens d', tokens = num_spatial_tokens_per_latent)
-        )
+            self.latents_to_spatial_tokens = Sequential(
+                Linear(dim_latent, dim * expand_factor),
+                Rearrange('... (s d) -> ... s d', s = expand_factor)
+            )
+
+            self.to_latent_pred = Sequential(
+                Reduce('b t n s d -> b t n d', 'mean'),
+                RMSNorm(dim),
+                LinearNoBias(dim, dim_latent)
+            )
+
+        else:
+            assert divisible_by(num_latent_tokens, num_spatial_tokens)
+            latent_tokens_to_space = num_latent_tokens // num_spatial_tokens
+
+            self.latents_to_spatial_tokens = Sequential(
+                Rearrange('b t n d -> b t (n d)'),
+                Linear(num_latent_tokens * dim_latent, dim * num_spatial_tokens),
+                Rearrange('b t (s d) -> b t s d', s = num_spatial_tokens)
+            )
+
+            self.to_latent_pred = Sequential(
+                RMSNorm(dim),
+                LinearNoBias(dim, dim_latent * latent_tokens_to_space),
+                Rearrange('b t s (n d) -> b t (s n) d', n = latent_tokens_to_space)
+            )
+
+        # register tokens

         self.num_register_tokens = num_register_tokens
         self.register_tokens = Parameter(torch.randn(num_register_tokens, dim) * 1e-2)

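The net effect of the two branches: with more spatial than latent tokens, each latent is projected up and split; with fewer, the latents are flattened together and projected down. A minimal shape sketch with toy sizes (batch 2, time 3, 4 latents, dim_latent 32, dim 64; all numbers illustrative, not from the commit):

import torch
from torch import nn
from einops.layers.torch import Rearrange

dim_latent, dim = 32, 64
latents = torch.randn(2, 3, 4, dim_latent)           # (batch, time, num_latent_tokens = 4, dim_latent)

# expand branch: 8 spatial tokens from 4 latents -> expand_factor of 2
expand = nn.Sequential(
    nn.Linear(dim_latent, dim * 2),                  # project each latent up
    Rearrange('... (s d) -> ... s d', s = 2)         # split into 2 spatial tokens per latent
)
assert expand(latents).shape == (2, 3, 4, 2, dim)    # 4 latents * 2 = 8 spatial tokens

# contract branch: 2 spatial tokens from 4 latents
contract = nn.Sequential(
    Rearrange('b t n d -> b t (n d)'),               # flatten all latents together
    nn.Linear(4 * dim_latent, dim * 2),              # project down to all spatial tokens at once
    Rearrange('b t (s d) -> b t s d', s = 2)         # split into the 2 spatial tokens
)
assert contract(latents).shape == (2, 3, 2, dim)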
@@ -1084,13 +1122,6 @@ class DynamicsModel(Module):
         self.layers = ModuleList(layers)
         self.is_time = is_time

-        # to prediction
-
-        self.to_pred = Sequential(
-            RMSNorm(dim),
-            Linear(dim, dim_latent)
-        )
-
     def parameter(self):
         params = super().parameters()

@@ -1121,6 +1152,8 @@ class DynamicsModel(Module):
         if latents.ndim == 3:
             latents = rearrange(latents, 'b t d -> b t 1 d') # 1 latent edge case

+        assert latents.shape[-2] == self.num_latent_tokens
+
         # variables

         batch, time, device = *latents.shape[:2], latents.device
@@ -1262,9 +1295,7 @@ class DynamicsModel(Module):

         space_tokens = inverse_pack_space_per_latent(space_tokens)

-        pooled = reduce(space_tokens, 'b t nl s d -> b t nl d', 'mean')
-
-        pred = self.to_pred(pooled)
+        pred = self.to_latent_pred(space_tokens)

         return pred
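With pooling folded into `to_latent_pred` as a `Reduce` layer in the expand branch, the standalone `self.to_pred` head became dead code and was removed above. The layer form is point-for-point equivalent to the old functional call; a quick check on a toy tensor (shapes illustrative):

import torch
from einops import reduce
from einops.layers.torch import Reduce

space_tokens = torch.randn(2, 3, 4, 2, 64)    # (batch, time, latents, spatial-per-latent, dim)

pooled_fn    = reduce(space_tokens, 'b t n s d -> b t n d', 'mean')   # old functional call
pooled_layer = Reduce('b t n s d -> b t n d', 'mean')(space_tokens)   # new Sequential-friendly layer

assert torch.allclose(pooled_fn, pooled_layer)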
@@ -7,16 +7,24 @@ import torch
 @param('dynamics_with_video_input', (False, True))
 @param('prob_no_shortcut_train', (None, 0., 1.))
 @param('add_task_embeds', (False, True))
+@param('num_spatial_tokens', (2, 8))
 def test_e2e(
     pred_orig_latent,
     grouped_query_attn,
     dynamics_with_video_input,
     prob_no_shortcut_train,
-    add_task_embeds
+    add_task_embeds,
+    num_spatial_tokens
 ):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel

-    tokenizer = VideoTokenizer(512, dim_latent = 32, patch_size = 32)
+    tokenizer = VideoTokenizer(
+        512,
+        dim_latent = 32,
+        patch_size = 32,
+        num_latent_tokens = 4
+    )

     video = torch.randn(2, 3, 4, 256, 256)

     loss = tokenizer(video)
@@ -33,6 +41,8 @@ def test_e2e(
         dim_latent = 32,
         max_steps = 64,
         num_tasks = 4,
+        num_latent_tokens = 4,
+        num_spatial_tokens = num_spatial_tokens,
         pred_orig_latent = pred_orig_latent,
         attn_kwargs = dict(
             heads = heads,
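End to end, the parametrized test now exercises both regimes. A hedged usage sketch (the first positional argument to `DynamicsModel` is assumed to be `dim`, mirroring `VideoTokenizer`; it is not shown in this diff, and remaining arguments are left at their defaults):

from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel

tokenizer = VideoTokenizer(512, dim_latent = 32, patch_size = 32, num_latent_tokens = 4)

for num_spatial_tokens in (8, 2):    # more, then fewer, spatial tokens than the 4 latents
    model = DynamicsModel(
        512,                         # assumed to be `dim`, matching the tokenizer above
        video_tokenizer = tokenizer,
        dim_latent = 32,
        num_latent_tokens = 4,
        num_spatial_tokens = num_spatial_tokens
    )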