From 4ce82f34df213497e4865caa02ddf4b6743ba81c Mon Sep 17 00:00:00 2001 From: lucidrains Date: Sat, 25 Oct 2025 09:20:55 -0700 Subject: [PATCH] given the VAT paper, add multiple video streams (third person, wrist camera, etc), geared for robotics. need to manage an extra dimension for multiple viewpoints --- dreamer4/dreamer4.py | 77 ++++++++++++++++++++++++++++++++++--------- pyproject.toml | 2 +- tests/test_dreamer.py | 12 +++++-- 3 files changed, 72 insertions(+), 19 deletions(-) diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index dc81bb3..9c22615 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -45,6 +45,7 @@ from assoc_scan import AssocScan # vc - video channels # vh, vw - video height and width # mtp - multi token prediction length +# v - video viewpoints import einx from einx import add, multiply @@ -1735,6 +1736,7 @@ class DynamicsWorldModel(Module): num_latent_tokens = None, num_agents = 1, num_tasks = 0, + num_video_views = 1, dim_proprio = None, reward_encoder_kwargs: dict = dict(), depth = 4, @@ -1800,7 +1802,7 @@ class DynamicsWorldModel(Module): ) self.to_latent_pred = Sequential( - Reduce('b t n s d -> b t n d', 'mean'), + Reduce('b t v n s d -> b t v n d', 'mean'), RMSNorm(dim), LinearNoBias(dim, dim_latent) ) @@ -1810,17 +1812,27 @@ class DynamicsWorldModel(Module): latent_tokens_to_space = num_latent_tokens // num_spatial_tokens self.latents_to_spatial_tokens = Sequential( - Rearrange('b t n d -> b t (n d)'), + Rearrange('... n d -> ... (n d)'), Linear(num_latent_tokens * dim_latent, dim * num_spatial_tokens), - Rearrange('b t (s d) -> b t s d', s = num_spatial_tokens) + Rearrange('... (s d) -> ... s d', s = num_spatial_tokens) ) self.to_latent_pred = Sequential( RMSNorm(dim), LinearNoBias(dim, dim_latent * latent_tokens_to_space), - Rearrange('b t s (n d) -> b t (s n) d', n = latent_tokens_to_space) + Rearrange('b t v s (n d) -> b t v (s n) d', n = latent_tokens_to_space) ) + # number of video views, for robotics, which could have third person + wrist camera at least + + assert num_video_views >= 1 + self.video_has_multi_view = num_video_views > 1 + + self.num_video_views = num_video_views + + if self.video_has_multi_view: + self.view_emb = nn.Parameter(torch.randn(num_video_views, dim) * 1e-2) + # proprioception self.has_proprio = exists(dim_proprio) @@ -2318,7 +2330,7 @@ class DynamicsWorldModel(Module): # denoising # teacher forcing to start with - latents = empty((batch_size, 0, *latent_shape), device = self.device) + latents = empty((batch_size, 0, self.num_video_views, *latent_shape), device = self.device) past_latents_context_noise = latents.clone() @@ -2354,7 +2366,7 @@ class DynamicsWorldModel(Module): curr_time_steps = latents.shape[1] - noised_latent = randn((batch_size, 1, *latent_shape), device = self.device) + noised_latent = randn((batch_size, 1, self.num_video_views, *latent_shape), device = self.device) noised_proprio = None @@ -2365,12 +2377,12 @@ class DynamicsWorldModel(Module): is_last_step = (step + 1) == num_steps signal_levels = full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device) - + # noising past latent context noised_context = latents.lerp(past_latents_context_noise, context_signal_noise) # the paragraph after eq (8) - noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * n d') + noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * v n d') # handle proprio @@ -2395,6 +2407,7 @@ class DynamicsWorldModel(Module): 
proprio = noised_proprio_with_context, time_kv_cache = time_kv_cache, latent_is_noised = True, + latent_has_view_dim = True, return_pred_only = True, return_intermediates = True, ) @@ -2409,12 +2422,11 @@ class DynamicsWorldModel(Module): # unpack pred - _, pred = unpack(pred, pack_context_shape, 'b * n d') + _, pred = unpack(pred, pack_context_shape, 'b * v n d') if has_proprio: _, pred_proprio = unpack(pred_proprio, pack_context_shape, 'b * d') - # derive flow, based on whether in x-space or not def denoise_step(pred, noised, signal_levels): @@ -2507,12 +2519,26 @@ class DynamicsWorldModel(Module): video = None if return_decoded_video: + + latents_for_video = rearrange(latents, 'b t v n d -> b v t n d') + latents_for_video, unpack_view = pack_one(latents_for_video, '* t n d') + video = self.video_tokenizer.decode( - latents, + latents_for_video, height = image_height, width = image_width ) + video = unpack_view(video, '* t c vh vw') + + # remove the lone view dimension + + if not self.video_has_multi_view: + latents = rearrange(latents, 'b t 1 ... -> b t ...') + + if exists(video): + video = rearrange(video, 'b 1 ... -> b ...') + # only return video or latent if not requesting anything else, for first stage training if not has_at_least_one(return_rewards_per_frame, return_agent_actions, has_proprio): @@ -2553,8 +2579,8 @@ class DynamicsWorldModel(Module): def forward( self, *, - video = None, # (b c t vh vw) - latents = None, # (b t n d) | (b t d) + video = None, # (b v? c t vh vw) + latents = None, # (b t v? n d) | (b t v? d) signal_levels = None, # () | (b) | (b t) step_sizes = None, # () | (b) step_sizes_log2 = None, # () | (b) @@ -2572,21 +2598,39 @@ class DynamicsWorldModel(Module): return_all_losses = False, return_intermediates = False, add_autoregressive_action_loss = False, - update_loss_ema = None + update_loss_ema = None, + latent_has_view_dim = False ): # handle video or latents assert exists(video) ^ exists(latents) + # standardize view dimension + + if not self.video_has_multi_view: + if exists(video): + video = rearrange(video, 'b ... -> b 1 ...') + + if exists(latents) and not latent_has_view_dim: + latents = rearrange(latents, 'b t ... 
-> b t 1 ...') + + # if raw video passed in, tokenize + if exists(video): + assert video.ndim == 6 + + video, unpack_views = pack_one(video, '* c t vh vw') assert exists(self.video_tokenizer), 'video_tokenizer must be passed in if training from raw video on dynamics model' latents = self.video_tokenizer.tokenize(video) + latents = unpack_views(latents, '* t n d') + latents = rearrange(latents, 'b v t n d -> b t v n d') - if latents.ndim == 3: - latents = rearrange(latents, 'b t d -> b t 1 d') # 1 latent edge case + if latents.ndim == 4: + latents = rearrange(latents, 'b t v d -> b t v 1 d') # 1 latent edge case assert latents.shape[-2:] == self.latent_shape + assert latents.shape[2] == self.num_video_views # variables @@ -2769,6 +2813,7 @@ class DynamicsWorldModel(Module): # main function, needs to be defined as such for shortcut training - additional calls for consistency loss def get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_kv_cache = False): + # latents to spatial tokens space_tokens = self.latents_to_spatial_tokens(noised_latents) diff --git a/pyproject.toml b/pyproject.toml index a11bf77..d0eedcd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.0.70" +version = "0.0.71" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" } diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py index 39ed0ce..a208b66 100644 --- a/tests/test_dreamer.py +++ b/tests/test_dreamer.py @@ -668,7 +668,10 @@ def test_online_rl( trainer(mock_env, num_episodes = 2, env_is_vectorized = vectorized) -def test_proprioception(): +@param('num_video_views', (1, 2)) +def test_proprioception( + num_video_views +): from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel tokenizer = VideoTokenizer( @@ -693,11 +696,16 @@ def test_proprioception(): dim_latent = 32, dim_proprio = 21, num_tasks = 4, + num_video_views = num_video_views, num_discrete_actions = 4, num_residual_streams = 1 ) - video = torch.randn(2, 3, 10, 256, 256) + if num_video_views > 1: + video = torch.randn(2, num_video_views, 3, 10, 256, 256) + else: + video = torch.randn(2, 3, 10, 256, 256) + rewards = torch.randn(2, 10) proprio = torch.randn(2, 10, 21) discrete_actions = torch.randint(0, 4, (2, 10, 1))
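Editor's note on the pattern this patch adds: the new viewpoint axis `v` is folded into the batch dimension whenever a per-view module (the video tokenizer) runs, then unfolded and moved next to the time axis for the world model, giving the `(b t v n d)` latent layout used throughout the hunks above. Below is a minimal, self-contained sketch of that fold/unfold pattern using only torch and einops. The local `pack_one` helper and `fake_tokenizer` are illustrative stand-ins, not the repo's actual `pack_one` or `VideoTokenizer.tokenize`; shapes mirror the updated `test_proprioception` test.

```python
# sketch of the multi-view packing pattern from this patch (assumptions noted in comments)

import torch
from einops import rearrange, pack, unpack

def pack_one(t, pattern):
    # assumed re-implementation of the repo's helper: pack a single tensor, return an inverse
    packed, ps = pack([t], pattern)

    def unpack_one(out, unpack_pattern = pattern):
        return unpack(out, ps, unpack_pattern)[0]

    return packed, unpack_one

def fake_tokenizer(video):
    # stand-in for VideoTokenizer.tokenize: (b c t vh vw) -> (b t n d)
    n, d = 4, 32
    return torch.randn(video.shape[0], video.shape[2], n, d)

b, v, c, t, h, w = 2, 2, 3, 10, 64, 64          # v = video viewpoints (third person, wrist camera)
video = torch.randn(b, v, c, t, h, w)            # (b v c t vh vw), per the updated forward docstring

# fold views into batch, tokenize each view independently, then restore the view dimension
video_packed, unpack_views = pack_one(video, '* c t vh vw')   # ((b v) c t vh vw)
latents = fake_tokenizer(video_packed)                         # ((b v) t n d)
latents = unpack_views(latents, '* t n d')                     # (b v t n d)
latents = rearrange(latents, 'b v t n d -> b t v n d')         # (b t v n d), world-model layout

print(latents.shape)  # torch.Size([2, 10, 2, 4, 32])
```

With `num_video_views = 1` the patch keeps backward compatibility by inserting and later squeezing a lone view dimension (`'b t ... -> b t 1 ...'` on the way in, `'b t 1 ... -> b t ...'` on the way out), so existing single-view callers and latents pass through unchanged.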