given the VAT paper, add multiple video streams (third person, wrist camera, etc), geared for robotics. need to manage an extra dimension for multiple viewpoints

This commit is contained in:
lucidrains 2025-10-25 09:20:55 -07:00
parent a9b728c611
commit 4ce82f34df
3 changed files with 72 additions and 19 deletions

View File

@ -45,6 +45,7 @@ from assoc_scan import AssocScan
# vc - video channels
# vh, vw - video height and width
# mtp - multi token prediction length
# v - video viewpoints
import einx
from einx import add, multiply
@ -1735,6 +1736,7 @@ class DynamicsWorldModel(Module):
num_latent_tokens = None,
num_agents = 1,
num_tasks = 0,
num_video_views = 1,
dim_proprio = None,
reward_encoder_kwargs: dict = dict(),
depth = 4,
@ -1800,7 +1802,7 @@ class DynamicsWorldModel(Module):
)
self.to_latent_pred = Sequential(
Reduce('b t n s d -> b t n d', 'mean'),
Reduce('b t v n s d -> b t v n d', 'mean'),
RMSNorm(dim),
LinearNoBias(dim, dim_latent)
)
@ -1810,17 +1812,27 @@ class DynamicsWorldModel(Module):
latent_tokens_to_space = num_latent_tokens // num_spatial_tokens
self.latents_to_spatial_tokens = Sequential(
Rearrange('b t n d -> b t (n d)'),
Rearrange('... n d -> ... (n d)'),
Linear(num_latent_tokens * dim_latent, dim * num_spatial_tokens),
Rearrange('b t (s d) -> b t s d', s = num_spatial_tokens)
Rearrange('... (s d) -> ... s d', s = num_spatial_tokens)
)
self.to_latent_pred = Sequential(
RMSNorm(dim),
LinearNoBias(dim, dim_latent * latent_tokens_to_space),
Rearrange('b t s (n d) -> b t (s n) d', n = latent_tokens_to_space)
Rearrange('b t v s (n d) -> b t v (s n) d', n = latent_tokens_to_space)
)
# number of video views, for robotics, which could have third person + wrist camera at least
assert num_video_views >= 1
self.video_has_multi_view = num_video_views > 1
self.num_video_views = num_video_views
if self.video_has_multi_view:
self.view_emb = nn.Parameter(torch.randn(num_video_views, dim) * 1e-2)
# proprioception
self.has_proprio = exists(dim_proprio)
@ -2318,7 +2330,7 @@ class DynamicsWorldModel(Module):
# denoising
# teacher forcing to start with
latents = empty((batch_size, 0, *latent_shape), device = self.device)
latents = empty((batch_size, 0, self.num_video_views, *latent_shape), device = self.device)
past_latents_context_noise = latents.clone()
@ -2354,7 +2366,7 @@ class DynamicsWorldModel(Module):
curr_time_steps = latents.shape[1]
noised_latent = randn((batch_size, 1, *latent_shape), device = self.device)
noised_latent = randn((batch_size, 1, self.num_video_views, *latent_shape), device = self.device)
noised_proprio = None
@ -2365,12 +2377,12 @@ class DynamicsWorldModel(Module):
is_last_step = (step + 1) == num_steps
signal_levels = full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
# noising past latent context
noised_context = latents.lerp(past_latents_context_noise, context_signal_noise) # the paragraph after eq (8)
noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * n d')
noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * v n d')
# handle proprio
@ -2395,6 +2407,7 @@ class DynamicsWorldModel(Module):
proprio = noised_proprio_with_context,
time_kv_cache = time_kv_cache,
latent_is_noised = True,
latent_has_view_dim = True,
return_pred_only = True,
return_intermediates = True,
)
@ -2409,12 +2422,11 @@ class DynamicsWorldModel(Module):
# unpack pred
_, pred = unpack(pred, pack_context_shape, 'b * n d')
_, pred = unpack(pred, pack_context_shape, 'b * v n d')
if has_proprio:
_, pred_proprio = unpack(pred_proprio, pack_context_shape, 'b * d')
# derive flow, based on whether in x-space or not
def denoise_step(pred, noised, signal_levels):
@ -2507,12 +2519,26 @@ class DynamicsWorldModel(Module):
video = None
if return_decoded_video:
latents_for_video = rearrange(latents, 'b t v n d -> b v t n d')
latents_for_video, unpack_view = pack_one(latents_for_video, '* t n d')
video = self.video_tokenizer.decode(
latents,
latents_for_video,
height = image_height,
width = image_width
)
video = unpack_view(video, '* t c vh vw')
# remove the lone view dimension
if not self.video_has_multi_view:
latents = rearrange(latents, 'b t 1 ... -> b t ...')
if exists(video):
video = rearrange(video, 'b 1 ... -> b ...')
# only return video or latent if not requesting anything else, for first stage training
if not has_at_least_one(return_rewards_per_frame, return_agent_actions, has_proprio):
@ -2553,8 +2579,8 @@ class DynamicsWorldModel(Module):
def forward(
self,
*,
video = None, # (b c t vh vw)
latents = None, # (b t n d) | (b t d)
video = None, # (b v? c t vh vw)
latents = None, # (b t v? n d) | (b t v? d)
signal_levels = None, # () | (b) | (b t)
step_sizes = None, # () | (b)
step_sizes_log2 = None, # () | (b)
@ -2572,21 +2598,39 @@ class DynamicsWorldModel(Module):
return_all_losses = False,
return_intermediates = False,
add_autoregressive_action_loss = False,
update_loss_ema = None
update_loss_ema = None,
latent_has_view_dim = False
):
# handle video or latents
assert exists(video) ^ exists(latents)
# standardize view dimension
if not self.video_has_multi_view:
if exists(video):
video = rearrange(video, 'b ... -> b 1 ...')
if exists(latents) and not latent_has_view_dim:
latents = rearrange(latents, 'b t ... -> b t 1 ...')
# if raw video passed in, tokenize
if exists(video):
assert video.ndim == 6
video, unpack_views = pack_one(video, '* c t vh vw')
assert exists(self.video_tokenizer), 'video_tokenizer must be passed in if training from raw video on dynamics model'
latents = self.video_tokenizer.tokenize(video)
latents = unpack_views(latents, '* t n d')
latents = rearrange(latents, 'b v t n d -> b t v n d')
if latents.ndim == 3:
latents = rearrange(latents, 'b t d -> b t 1 d') # 1 latent edge case
if latents.ndim == 4:
latents = rearrange(latents, 'b t v d -> b t v 1 d') # 1 latent edge case
assert latents.shape[-2:] == self.latent_shape
assert latents.shape[2] == self.num_video_views
# variables
@ -2769,6 +2813,7 @@ class DynamicsWorldModel(Module):
# main function, needs to be defined as such for shortcut training - additional calls for consistency loss
def get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_kv_cache = False):
# latents to spatial tokens
space_tokens = self.latents_to_spatial_tokens(noised_latents)

View File

@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.70"
version = "0.0.71"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

View File

@ -668,7 +668,10 @@ def test_online_rl(
trainer(mock_env, num_episodes = 2, env_is_vectorized = vectorized)
def test_proprioception():
@param('num_video_views', (1, 2))
def test_proprioception(
num_video_views
):
from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
tokenizer = VideoTokenizer(
@ -693,11 +696,16 @@ def test_proprioception():
dim_latent = 32,
dim_proprio = 21,
num_tasks = 4,
num_video_views = num_video_views,
num_discrete_actions = 4,
num_residual_streams = 1
)
video = torch.randn(2, 3, 10, 256, 256)
if num_video_views > 1:
video = torch.randn(2, num_video_views, 3, 10, 256, 256)
else:
video = torch.randn(2, 3, 10, 256, 256)
rewards = torch.randn(2, 10)
proprio = torch.randn(2, 10, 21)
discrete_actions = torch.randint(0, 4, (2, 10, 1))