given the VAT paper, add support for multiple video streams (third person, wrist camera, etc.), geared for robotics. this requires managing an extra view dimension for the multiple viewpoints
parent: a9b728c611
commit: 4ce82f34df
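For orientation before the diff itself, a shape-only sketch (plain torch + einops, no model involved) of the bookkeeping this commit introduces: multi-view video arrives as `(b v c t vh vw)`, the view axis is folded into the batch so the existing single-view tokenizer can be reused unchanged, and the resulting latents carry the view axis right after time as `(b t v n d)`. The random tensor standing in for `video_tokenizer.tokenize` is an assumption for illustration only.

```python
import torch
from einops import rearrange, pack, unpack

b, v, c, t, vh, vw = 2, 2, 3, 10, 64, 64   # batch, views, channels, time, height, width
n, d = 16, 32                              # latent tokens per frame, latent dimension

video = torch.randn(b, v, c, t, vh, vw)    # new multi-view layout (b v c t vh vw)

# fold the view axis into batch so a single-view tokenizer can process every view
video_flat, ps = pack([video], '* c t vh vw')             # (b * v, c, t, vh, vw)

# stand-in for video_tokenizer.tokenize(video_flat) - assumed to yield (.., t, n, d)
latents_flat = torch.randn(video_flat.shape[0], t, n, d)

# restore the view axis, then move it after time - the internal layout used by the world model
latents = unpack(latents_flat, ps, '* t n d')[0]          # (b, v, t, n, d)
latents = rearrange(latents, 'b v t n d -> b t v n d')

assert latents.shape == (b, t, v, n, d)
```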
@@ -45,6 +45,7 @@ from assoc_scan import AssocScan
 # vc - video channels
 # vh, vw - video height and width
 # mtp - multi token prediction length
+# v - video viewpoints

 import einx
 from einx import add, multiply
@@ -1735,6 +1736,7 @@ class DynamicsWorldModel(Module):
 num_latent_tokens = None,
 num_agents = 1,
 num_tasks = 0,
+num_video_views = 1,
 dim_proprio = None,
 reward_encoder_kwargs: dict = dict(),
 depth = 4,
@@ -1800,7 +1802,7 @@ class DynamicsWorldModel(Module):
 )

 self.to_latent_pred = Sequential(
-    Reduce('b t n s d -> b t n d', 'mean'),
+    Reduce('b t v n s d -> b t v n d', 'mean'),
     RMSNorm(dim),
     LinearNoBias(dim, dim_latent)
 )
@@ -1810,17 +1812,27 @@ class DynamicsWorldModel(Module):
 latent_tokens_to_space = num_latent_tokens // num_spatial_tokens

 self.latents_to_spatial_tokens = Sequential(
-    Rearrange('b t n d -> b t (n d)'),
+    Rearrange('... n d -> ... (n d)'),
     Linear(num_latent_tokens * dim_latent, dim * num_spatial_tokens),
-    Rearrange('b t (s d) -> b t s d', s = num_spatial_tokens)
+    Rearrange('... (s d) -> ... s d', s = num_spatial_tokens)
 )

 self.to_latent_pred = Sequential(
     RMSNorm(dim),
     LinearNoBias(dim, dim_latent * latent_tokens_to_space),
-    Rearrange('b t s (n d) -> b t (s n) d', n = latent_tokens_to_space)
+    Rearrange('b t v s (n d) -> b t v (s n) d', n = latent_tokens_to_space)
 )

+# number of video views, for robotics, which could have third person + wrist camera at least
+
+assert num_video_views >= 1
+self.video_has_multi_view = num_video_views > 1
+
+self.num_video_views = num_video_views
+
+if self.video_has_multi_view:
+    self.view_emb = nn.Parameter(torch.randn(num_video_views, dim) * 1e-2)
+
 # proprioception

 self.has_proprio = exists(dim_proprio)
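The hunk above only defines `view_emb`; the place where it is added onto the spatial tokens is not visible in this excerpt, so the broadcast below is a hedged sketch of the usual pattern (a learned `(v, d)` embedding broadcast over batch, time and space), not the repository's actual wiring.

```python
import torch
from einops import rearrange

num_video_views, dim = 2, 64
b, t, v, s = 2, 10, num_video_views, 16      # batch, time, views, spatial tokens per frame

view_emb = torch.nn.Parameter(torch.randn(num_video_views, dim) * 1e-2)   # as defined in the commit
space_tokens = torch.randn(b, t, v, s, dim)

# broadcast the (v, d) embedding over batch, time and space: (b t v s d) + (v 1 d)
space_tokens = space_tokens + rearrange(view_emb, 'v d -> v 1 d')
```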
@@ -2318,7 +2330,7 @@ class DynamicsWorldModel(Module):
 # denoising
 # teacher forcing to start with

-latents = empty((batch_size, 0, *latent_shape), device = self.device)
+latents = empty((batch_size, 0, self.num_video_views, *latent_shape), device = self.device)

 past_latents_context_noise = latents.clone()

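The empty buffer above starts with a zero-length time axis, now also carrying the view axis; each denoised frame is appended along time during rollout. A minimal sketch of that growth follows — the concatenation step is an assumption about the surrounding loop, which is not shown in this excerpt.

```python
import torch

batch_size, num_video_views, n, d = 2, 2, 16, 32
latent_shape = (n, d)

# zero-length time axis to start, view axis in position 2
latents = torch.empty((batch_size, 0, num_video_views, *latent_shape))

# assumed append step for each newly generated frame
new_frame = torch.randn(batch_size, 1, num_video_views, *latent_shape)
latents = torch.cat((latents, new_frame), dim = 1)

assert latents.shape[1] == 1
```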
@@ -2354,7 +2366,7 @@ class DynamicsWorldModel(Module):

 curr_time_steps = latents.shape[1]

-noised_latent = randn((batch_size, 1, *latent_shape), device = self.device)
+noised_latent = randn((batch_size, 1, self.num_video_views, *latent_shape), device = self.device)

 noised_proprio = None

@@ -2365,12 +2377,12 @@ class DynamicsWorldModel(Module):
 is_last_step = (step + 1) == num_steps

 signal_levels = full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)

 # noising past latent context

 noised_context = latents.lerp(past_latents_context_noise, context_signal_noise) # the paragraph after eq (8)

-noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * n d')
+noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * v n d')

 # handle proprio

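A shape-only illustration of the widened pack pattern above: the past context latents and the frame currently being denoised are concatenated along the wildcard (time) axis of `'b * v n d'` while the new view axis rides along untouched, and the current-frame prediction is later recovered with the matching `unpack`. Tensor sizes here are arbitrary.

```python
import torch
from einops import pack, unpack

b, v, n, d = 2, 2, 16, 32
noised_context = torch.randn(b, 5, v, n, d)   # five already-generated context frames
noised_latent = torch.randn(b, 1, v, n, d)    # the frame currently being denoised

packed, ps = pack((noised_context, noised_latent), 'b * v n d')   # (b, 6, v, n, d)

# after the transformer pass, the prediction for the current frame is split back out
_, pred = unpack(packed, ps, 'b * v n d')
assert pred.shape == (b, 1, v, n, d)
```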
@@ -2395,6 +2407,7 @@ class DynamicsWorldModel(Module):
 proprio = noised_proprio_with_context,
 time_kv_cache = time_kv_cache,
 latent_is_noised = True,
+latent_has_view_dim = True,
 return_pred_only = True,
 return_intermediates = True,
 )
@@ -2409,12 +2422,11 @@ class DynamicsWorldModel(Module):

 # unpack pred

-_, pred = unpack(pred, pack_context_shape, 'b * n d')
+_, pred = unpack(pred, pack_context_shape, 'b * v n d')

 if has_proprio:
     _, pred_proprio = unpack(pred_proprio, pack_context_shape, 'b * d')

 # derive flow, based on whether in x-space or not

 def denoise_step(pred, noised, signal_levels):
@@ -2507,12 +2519,26 @@ class DynamicsWorldModel(Module):
 video = None

 if return_decoded_video:

+    latents_for_video = rearrange(latents, 'b t v n d -> b v t n d')
+    latents_for_video, unpack_view = pack_one(latents_for_video, '* t n d')
+
     video = self.video_tokenizer.decode(
-        latents,
+        latents_for_video,
         height = image_height,
         width = image_width
     )

+    video = unpack_view(video, '* t c vh vw')
+
+# remove the lone view dimension
+
+if not self.video_has_multi_view:
+    latents = rearrange(latents, 'b t 1 ... -> b t ...')
+
+    if exists(video):
+        video = rearrange(video, 'b 1 ... -> b ...')
+
 # only return video or latent if not requesting anything else, for first stage training

 if not has_at_least_one(return_rewards_per_frame, return_agent_actions, has_proprio):
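The `remove the lone view dimension` block above is what keeps the API backwards compatible on the way out: a single-view model squeezes the singleton view axis, so callers keep receiving the pre-commit shapes. A shape-only sketch:

```python
import torch
from einops import rearrange

latents = torch.randn(2, 10, 1, 16, 32)     # (b t 1 n d) from a single-view model
video = torch.randn(2, 1, 10, 3, 64, 64)    # (b 1 t c vh vw) decoded view by view

latents = rearrange(latents, 'b t 1 ... -> b t ...')   # (b t n d), same as before this commit
video = rearrange(video, 'b 1 ... -> b ...')           # (b t c vh vw)
```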
@@ -2553,8 +2579,8 @@ class DynamicsWorldModel(Module):
 def forward(
     self,
     *,
-    video = None, # (b c t vh vw)
-    latents = None, # (b t n d) | (b t d)
+    video = None, # (b v? c t vh vw)
+    latents = None, # (b t v? n d) | (b t v? d)
     signal_levels = None, # () | (b) | (b t)
     step_sizes = None, # () | (b)
     step_sizes_log2 = None, # () | (b)
@@ -2572,21 +2598,39 @@ class DynamicsWorldModel(Module):
 return_all_losses = False,
 return_intermediates = False,
 add_autoregressive_action_loss = False,
-update_loss_ema = None
+update_loss_ema = None,
+latent_has_view_dim = False
 ):
 # handle video or latents

 assert exists(video) ^ exists(latents)

+# standardize view dimension
+
+if not self.video_has_multi_view:
+    if exists(video):
+        video = rearrange(video, 'b ... -> b 1 ...')
+
+    if exists(latents) and not latent_has_view_dim:
+        latents = rearrange(latents, 'b t ... -> b t 1 ...')
+
 # if raw video passed in, tokenize

 if exists(video):
+    assert video.ndim == 6
+
+    video, unpack_views = pack_one(video, '* c t vh vw')
     assert exists(self.video_tokenizer), 'video_tokenizer must be passed in if training from raw video on dynamics model'

     latents = self.video_tokenizer.tokenize(video)
+    latents = unpack_views(latents, '* t n d')
+    latents = rearrange(latents, 'b v t n d -> b t v n d')

-if latents.ndim == 3:
-    latents = rearrange(latents, 'b t d -> b t 1 d') # 1 latent edge case
+if latents.ndim == 4:
+    latents = rearrange(latents, 'b t v d -> b t v 1 d') # 1 latent edge case

 assert latents.shape[-2:] == self.latent_shape
+assert latents.shape[2] == self.num_video_views

 # variables

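Conversely, on the way in, the standardization block added to `forward` above maps both call conventions onto the same internal layout: old single-view inputs gain a singleton view axis, multi-view inputs pass through as-is. A shape-only sketch with plain tensors:

```python
import torch
from einops import rearrange

# old single-view latents, (b t n d), gain a singleton view axis
latents_single = torch.randn(2, 10, 16, 32)
latents_single = rearrange(latents_single, 'b t ... -> b t 1 ...')   # (b t 1 n d)

# multi-view latents already carrying the view axis, (b t v n d)
latents_multi = torch.randn(2, 10, 2, 16, 32)

assert latents_single.ndim == latents_multi.ndim == 5
```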
@@ -2769,6 +2813,7 @@ class DynamicsWorldModel(Module):
 # main function, needs to be defined as such for shortcut training - additional calls for consistency loss

 def get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_kv_cache = False):

     # latents to spatial tokens

     space_tokens = self.latents_to_spatial_tokens(noised_latents)

@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.70"
+version = "0.0.71"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -668,7 +668,10 @@ def test_online_rl(

 trainer(mock_env, num_episodes = 2, env_is_vectorized = vectorized)

-def test_proprioception():
+@param('num_video_views', (1, 2))
+def test_proprioception(
+    num_video_views
+):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel

     tokenizer = VideoTokenizer(
@@ -693,11 +696,16 @@ def test_proprioception():
 dim_latent = 32,
 dim_proprio = 21,
 num_tasks = 4,
+num_video_views = num_video_views,
 num_discrete_actions = 4,
 num_residual_streams = 1
 )

-video = torch.randn(2, 3, 10, 256, 256)
+if num_video_views > 1:
+    video = torch.randn(2, num_video_views, 3, 10, 256, 256)
+else:
+    video = torch.randn(2, 3, 10, 256, 256)

 rewards = torch.randn(2, 10)
 proprio = torch.randn(2, 10, 21)
 discrete_actions = torch.randint(0, 4, (2, 10, 1))
