diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index de03645..f21027b 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -37,7 +37,7 @@
 from accelerate import Accelerator
 import einx
 from einops import einsum, rearrange, repeat, reduce, pack, unpack
-from einops.layers.torch import Rearrange
+from einops.layers.torch import Rearrange, Reduce
 
 # flex attention - but will make sure it works if it is not available
 # may also end up crafting own custom flash attention kernel for this work
@@ -991,9 +991,10 @@ class DynamicsModel(Module):
         dim,
         dim_latent,
         video_tokenizer: VideoTokenizer | None = None,
-        max_steps = 64,                         # K_max in paper
-        num_register_tokens = 8,                # they claim register tokens led to better temporal consistency
-        num_spatial_tokens_per_latent = 2,      # latents can be projected to greater number of tokens
+        max_steps = 64,                # K_max in paper
+        num_register_tokens = 8,       # they claim register tokens led to better temporal consistency
+        num_spatial_tokens = 2,        # latents projected to greater number of spatial tokens
+        num_latent_tokens = None,
         num_tasks = 0,
         depth = 4,
         pred_orig_latent = True,       # directly predicting the original x0 data yield better results, rather than velocity (x-space vs v-space)
@@ -1014,12 +1015,49 @@ class DynamicsModel(Module):
 
         self.video_tokenizer = video_tokenizer
 
-        # spatial and register tokens
+        if exists(video_tokenizer):
+            num_latent_tokens = default(num_latent_tokens, video_tokenizer.num_latent_tokens)
+            assert video_tokenizer.num_latent_tokens == num_latent_tokens, f'`num_latent_tokens` must be the same for the tokenizer and dynamics model'
 
-        self.latents_to_spatial_tokens = Sequential(
-            Linear(dim_latent, dim * num_spatial_tokens_per_latent),
-            Rearrange('... (tokens d) -> ... tokens d', tokens = num_spatial_tokens_per_latent)
-        )
+        assert exists(num_latent_tokens), '`num_latent_tokens` must be set'
+
+        # spatial
+
+        self.num_latent_tokens = num_latent_tokens
+
+        if num_spatial_tokens >= num_latent_tokens:
+            assert divisible_by(num_spatial_tokens, num_latent_tokens)
+
+            expand_factor = num_spatial_tokens // num_latent_tokens
+
+            self.latents_to_spatial_tokens = Sequential(
+                Linear(dim_latent, dim * expand_factor),
+                Rearrange('... (s d) -> ... s d', s = expand_factor)
+            )
+
+            self.to_latent_pred = Sequential(
+                Reduce('b t n s d -> b t n d', 'mean'),
+                RMSNorm(dim),
+                LinearNoBias(dim, dim_latent)
+            )
+
+        else:
+            assert divisible_by(num_latent_tokens, num_spatial_tokens)
+            latent_tokens_to_space = num_latent_tokens // num_spatial_tokens
+
+            self.latents_to_spatial_tokens = Sequential(
+                Rearrange('b t n d -> b t (n d)'),
+                Linear(num_latent_tokens * dim_latent, dim * num_spatial_tokens),
+                Rearrange('b t (s d) -> b t s d', s = num_spatial_tokens)
+            )
+
+            self.to_latent_pred = Sequential(
+                RMSNorm(dim),
+                LinearNoBias(dim, dim_latent * latent_tokens_to_space),
+                Rearrange('b t s (n d) -> b t (s n) d', n = latent_tokens_to_space)
+            )
+
+        # register tokens
 
         self.num_register_tokens = num_register_tokens
         self.register_tokens = Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
@@ -1084,13 +1122,6 @@ class DynamicsModel(Module):
         self.layers = ModuleList(layers)
         self.is_time = is_time
 
-        # to prediction
-
-        self.to_pred = Sequential(
-            RMSNorm(dim),
-            Linear(dim, dim_latent)
-        )
-
     def parameter(self):
 
         params = super().parameters()
@@ -1121,6 +1152,8 @@ class DynamicsModel(Module):
         if latents.ndim == 3:
             latents = rearrange(latents, 'b t d -> b t 1 d') # 1 latent edge case
 
+        assert latents.shape[-2] == self.num_latent_tokens
+
         # variables
 
         batch, time, device = *latents.shape[:2], latents.device
@@ -1262,9 +1295,7 @@ class DynamicsModel(Module):
 
         space_tokens = inverse_pack_space_per_latent(space_tokens)
 
-        pooled = reduce(space_tokens, 'b t nl s d -> b t nl d', 'mean')
-
-        pred = self.to_pred(pooled)
+        pred = self.to_latent_pred(space_tokens)
 
         return pred
 
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index bb0bb1b..d11cc53 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -7,16 +7,24 @@ import torch
 @param('dynamics_with_video_input', (False, True))
 @param('prob_no_shortcut_train', (None, 0., 1.))
 @param('add_task_embeds', (False, True))
+@param('num_spatial_tokens', (2, 8))
 def test_e2e(
     pred_orig_latent,
     grouped_query_attn,
     dynamics_with_video_input,
     prob_no_shortcut_train,
-    add_task_embeds
+    add_task_embeds,
+    num_spatial_tokens
 ):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel
 
-    tokenizer = VideoTokenizer(512, dim_latent = 32, patch_size = 32)
+    tokenizer = VideoTokenizer(
+        512,
+        dim_latent = 32,
+        patch_size = 32,
+        num_latent_tokens = 4
+    )
+
     video = torch.randn(2, 3, 4, 256, 256)
 
     loss = tokenizer(video)
@@ -33,6 +41,8 @@ def test_e2e(
         dim_latent = 32,
         max_steps = 64,
         num_tasks = 4,
+        num_latent_tokens = 4,
+        num_spatial_tokens = num_spatial_tokens,
         pred_orig_latent = pred_orig_latent,
         attn_kwargs = dict(
             heads = heads,
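
The change above decouples the dynamics model's spatial token count from the tokenizer's latent token count. For illustration only, here is a standalone sketch of the expansion branch (num_spatial_tokens >= num_latent_tokens), using plain torch/einops stand-ins (LayerNorm and a bias-free Linear in place of the repo's RMSNorm and LinearNoBias); the names and shapes are assumptions for the sketch, not the library's API.

import torch
from torch import nn
from einops.layers.torch import Rearrange, Reduce

dim, dim_latent = 512, 32
num_latent_tokens, num_spatial_tokens = 4, 8            # expansion branch: 8 >= 4

expand_factor = num_spatial_tokens // num_latent_tokens  # = 2

# project each latent into `expand_factor` spatial tokens
latents_to_spatial = nn.Sequential(
    nn.Linear(dim_latent, dim * expand_factor),
    Rearrange('... (s d) -> ... s d', s = expand_factor)
)

# pool the expanded tokens back to one latent prediction each
to_latent_pred = nn.Sequential(
    Reduce('b t n s d -> b t n d', 'mean'),
    nn.LayerNorm(dim),                          # stand-in for RMSNorm
    nn.Linear(dim, dim_latent, bias = False)    # stand-in for LinearNoBias
)

latents = torch.randn(2, 3, num_latent_tokens, dim_latent)  # (batch, time, latents, dim_latent)

spatial = latents_to_spatial(latents)  # (2, 3, 4, 2, 512): 4 * 2 = 8 spatial tokens per timestep
pred    = to_latent_pred(spatial)      # (2, 3, 4, 32): same shape as the input latents

assert pred.shape == latents.shape

The compression branch (num_spatial_tokens < num_latent_tokens) goes the other way: all latents of a timestep are flattened, projected down to the smaller number of spatial tokens, and to_latent_pred re-expands each spatial token into num_latent_tokens // num_spatial_tokens latent predictions.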