given the VAT paper, add multiple video streams (third person, wrist camera, etc.) for robotics; this requires managing an extra dimension for the multiple viewpoints
commit 4ce82f34df
parent a9b728c611
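
The shape contract this introduces, as a minimal sketch (sizes below are illustrative, not from the repo):

    import torch

    # with num_video_views = 2 (say third person + wrist camera), video gains
    # a view axis right after batch
    batch, views, channels, frames, height, width = 4, 2, 3, 10, 64, 64
    multi_view_video = torch.randn(batch, views, channels, frames, height, width)  # (b v c t vh vw)

    # with the default num_video_views = 1, the previous single-view shape
    # (b c t vh vw) is still accepted and promoted internally
    single_view_video = torch.randn(batch, channels, frames, height, width)

Internally latents carry the view axis as (b t v n d), and the lone view dimension is stripped again on output when only one view is configured.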
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -45,6 +45,7 @@ from assoc_scan import AssocScan
 # vc - video channels
 # vh, vw - video height and width
 # mtp - multi token prediction length
+# v - video viewpoints
 
 import einx
 from einx import add, multiply
@@ -1735,6 +1736,7 @@ class DynamicsWorldModel(Module):
         num_latent_tokens = None,
         num_agents = 1,
         num_tasks = 0,
+        num_video_views = 1,
         dim_proprio = None,
         reward_encoder_kwargs: dict = dict(),
         depth = 4,
@@ -1800,7 +1802,7 @@ class DynamicsWorldModel(Module):
         )
 
         self.to_latent_pred = Sequential(
-            Reduce('b t n s d -> b t n d', 'mean'),
+            Reduce('b t v n s d -> b t v n d', 'mean'),
             RMSNorm(dim),
             LinearNoBias(dim, dim_latent)
         )
@@ -1810,17 +1812,27 @@ class DynamicsWorldModel(Module):
             latent_tokens_to_space = num_latent_tokens // num_spatial_tokens
 
             self.latents_to_spatial_tokens = Sequential(
-                Rearrange('b t n d -> b t (n d)'),
+                Rearrange('... n d -> ... (n d)'),
                 Linear(num_latent_tokens * dim_latent, dim * num_spatial_tokens),
-                Rearrange('b t (s d) -> b t s d', s = num_spatial_tokens)
+                Rearrange('... (s d) -> ... s d', s = num_spatial_tokens)
             )
 
             self.to_latent_pred = Sequential(
                 RMSNorm(dim),
                 LinearNoBias(dim, dim_latent * latent_tokens_to_space),
-                Rearrange('b t s (n d) -> b t (s n) d', n = latent_tokens_to_space)
+                Rearrange('b t v s (n d) -> b t v (s n) d', n = latent_tokens_to_space)
             )
 
+        # number of video views, for robotics, which could have third person + wrist camera at least
+
+        assert num_video_views >= 1
+        self.video_has_multi_view = num_video_views > 1
+
+        self.num_video_views = num_video_views
+
+        if self.video_has_multi_view:
+            self.view_emb = nn.Parameter(torch.randn(num_video_views, dim) * 1e-2)
+
         # proprioception
 
         self.has_proprio = exists(dim_proprio)
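
The switch to '...' patterns is what lets latents_to_spatial_tokens accept the extra view axis without further changes. A quick check of the einops behavior relied on here (dimension sizes arbitrary):

    import torch
    from einops.layers.torch import Rearrange, Reduce

    b, t, v, n, s, d = 2, 4, 2, 8, 6, 32

    # '...' absorbs any leading axes, so one module handles both
    # (b t n d) and (b t v n d) inputs
    flatten = Rearrange('... n d -> ... (n d)')
    assert flatten(torch.randn(b, t, n, d)).shape == (b, t, n * d)
    assert flatten(torch.randn(b, t, v, n, d)).shape == (b, t, v, n * d)

    # the prediction head now mean-pools the spatial axis per view
    pool = Reduce('b t v n s d -> b t v n d', 'mean')
    assert pool(torch.randn(b, t, v, n, s, d)).shape == (b, t, v, n, d)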
@@ -2318,7 +2330,7 @@ class DynamicsWorldModel(Module):
         # denoising
         # teacher forcing to start with
 
-        latents = empty((batch_size, 0, *latent_shape), device = self.device)
+        latents = empty((batch_size, 0, self.num_video_views, *latent_shape), device = self.device)
 
         past_latents_context_noise = latents.clone()
 
@@ -2354,7 +2366,7 @@ class DynamicsWorldModel(Module):
 
             curr_time_steps = latents.shape[1]
 
-            noised_latent = randn((batch_size, 1, *latent_shape), device = self.device)
+            noised_latent = randn((batch_size, 1, self.num_video_views, *latent_shape), device = self.device)
 
             noised_proprio = None
 
@@ -2365,12 +2377,12 @@ class DynamicsWorldModel(Module):
                 is_last_step = (step + 1) == num_steps
 
                 signal_levels = full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
 
                 # noising past latent context
 
                 noised_context = latents.lerp(past_latents_context_noise, context_signal_noise) # the paragraph after eq (8)
 
-                noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * n d')
+                noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * v n d')
 
                 # handle proprio
 
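
The pack pattern only grows a v between the time wildcard and the token axes; past context and the frame being denoised are still joined along time. Roughly, with arbitrary sizes:

    import torch
    from einops import pack, unpack

    b, v, n, d = 2, 2, 8, 32
    noised_context = torch.randn(b, 5, v, n, d)  # 5 past frames
    noised_latent = torch.randn(b, 1, v, n, d)   # frame being denoised

    # '*' marks the axis the tensors are concatenated on, here time
    packed, ps = pack((noised_context, noised_latent), 'b * v n d')
    assert packed.shape == (b, 6, v, n, d)

    # the same ps later splits the prediction back apart
    _, pred = unpack(packed, ps, 'b * v n d')
    assert pred.shape == (b, 1, v, n, d)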
@@ -2395,6 +2407,7 @@ class DynamicsWorldModel(Module):
                     proprio = noised_proprio_with_context,
                     time_kv_cache = time_kv_cache,
                     latent_is_noised = True,
+                    latent_has_view_dim = True,
                     return_pred_only = True,
                     return_intermediates = True,
                 )
@@ -2409,12 +2422,11 @@ class DynamicsWorldModel(Module):
 
                 # unpack pred
 
-                _, pred = unpack(pred, pack_context_shape, 'b * n d')
+                _, pred = unpack(pred, pack_context_shape, 'b * v n d')
 
                 if has_proprio:
                     _, pred_proprio = unpack(pred_proprio, pack_context_shape, 'b * d')
 
-
                 # derive flow, based on whether in x-space or not
 
                 def denoise_step(pred, noised, signal_levels):
@@ -2507,12 +2519,26 @@ class DynamicsWorldModel(Module):
         video = None
 
         if return_decoded_video:
 
+            latents_for_video = rearrange(latents, 'b t v n d -> b v t n d')
+            latents_for_video, unpack_view = pack_one(latents_for_video, '* t n d')
+
             video = self.video_tokenizer.decode(
-                latents,
+                latents_for_video,
                 height = image_height,
                 width = image_width
             )
 
+            video = unpack_view(video, '* t c vh vw')
+
+        # remove the lone view dimension
+
+        if not self.video_has_multi_view:
+            latents = rearrange(latents, 'b t 1 ... -> b t ...')
+
+            if exists(video):
+                video = rearrange(video, 'b 1 ... -> b ...')
+
         # only return video or latent if not requesting anything else, for first stage training
 
         if not has_at_least_one(return_rewards_per_frame, return_agent_actions, has_proprio):
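
Since the video tokenizer itself stays single-view, the view axis is folded into batch before decoding and recovered afterwards. A stand-in for the repo's pack_one helper (an assumption: it packs one tensor and returns an inverse that unpacks with a new pattern) shows the shape round trip:

    import torch
    from einops import pack, unpack, rearrange

    def pack_one(t, pattern):
        # hypothetical minimal version of the helper used in the diff
        packed, ps = pack([t], pattern)

        def inverse(out, inv_pattern = pattern):
            return unpack(out, ps, inv_pattern)[0]

        return packed, inverse

    b, t, v, n, d = 2, 10, 2, 8, 32
    latents = torch.randn(b, t, v, n, d)

    # fold views into batch: (b t v n d) -> ((b v) t n d)
    latents_for_video = rearrange(latents, 'b t v n d -> b v t n d')
    latents_for_video, unpack_view = pack_one(latents_for_video, '* t n d')
    assert latents_for_video.shape == (b * v, t, n, d)

    # a decoded video of shape ((b v) t c vh vw) would then unpack back to
    # (b v t c vh vw) via unpack_view(video, '* t c vh vw')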
@@ -2553,8 +2579,8 @@ class DynamicsWorldModel(Module):
     def forward(
         self,
         *,
-        video = None, # (b c t vh vw)
-        latents = None, # (b t n d) | (b t d)
+        video = None, # (b v? c t vh vw)
+        latents = None, # (b t v? n d) | (b t v? d)
         signal_levels = None, # () | (b) | (b t)
         step_sizes = None, # () | (b)
         step_sizes_log2 = None, # () | (b)
@@ -2572,21 +2598,39 @@ class DynamicsWorldModel(Module):
         return_all_losses = False,
         return_intermediates = False,
         add_autoregressive_action_loss = False,
-        update_loss_ema = None
+        update_loss_ema = None,
+        latent_has_view_dim = False
     ):
         # handle video or latents
 
         assert exists(video) ^ exists(latents)
 
+        # standardize view dimension
+
+        if not self.video_has_multi_view:
+            if exists(video):
+                video = rearrange(video, 'b ... -> b 1 ...')
+
+            if exists(latents) and not latent_has_view_dim:
+                latents = rearrange(latents, 'b t ... -> b t 1 ...')
+
+        # if raw video passed in, tokenize
+
         if exists(video):
+            assert video.ndim == 6
+
+            video, unpack_views = pack_one(video, '* c t vh vw')
             assert exists(self.video_tokenizer), 'video_tokenizer must be passed in if training from raw video on dynamics model'
 
             latents = self.video_tokenizer.tokenize(video)
+            latents = unpack_views(latents, '* t n d')
+            latents = rearrange(latents, 'b v t n d -> b t v n d')
 
-        if latents.ndim == 3:
-            latents = rearrange(latents, 'b t d -> b t 1 d') # 1 latent edge case
+        if latents.ndim == 4:
+            latents = rearrange(latents, 'b t v d -> b t v 1 d') # 1 latent edge case
 
         assert latents.shape[-2:] == self.latent_shape
+        assert latents.shape[2] == self.num_video_views
 
         # variables
 
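
The standardization at the top of forward means callers never thread the view axis themselves in the single-view case, while latent_has_view_dim = True is passed internally (e.g. from the sampling loop above) where latents already arrive as (b t v n d). The promotion itself, with illustrative shapes:

    import torch
    from einops import rearrange

    # single-view video (b c t vh vw) gains a lone view axis
    video = torch.randn(2, 3, 10, 64, 64)
    video = rearrange(video, 'b ... -> b 1 ...')
    assert video.shape == (2, 1, 3, 10, 64, 64)

    # likewise single-view latents (b t n d) -> (b t 1 n d)
    latents = torch.randn(2, 10, 8, 32)
    latents = rearrange(latents, 'b t ... -> b t 1 ...')
    assert latents.shape == (2, 10, 1, 8, 32)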
@@ -2769,6 +2813,7 @@ class DynamicsWorldModel(Module):
         # main function, needs to be defined as such for shortcut training - additional calls for consistency loss
 
         def get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_kv_cache = False):
 
             # latents to spatial tokens
 
             space_tokens = self.latents_to_spatial_tokens(noised_latents)
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.70"
+version = "0.0.71"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
--- a/tests/test_dreamer4.py
+++ b/tests/test_dreamer4.py
@@ -668,7 +668,10 @@ def test_online_rl(
 
     trainer(mock_env, num_episodes = 2, env_is_vectorized = vectorized)
 
-def test_proprioception():
+@param('num_video_views', (1, 2))
+def test_proprioception(
+    num_video_views
+):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
 
     tokenizer = VideoTokenizer(
@@ -693,11 +696,16 @@ def test_proprioception():
         dim_latent = 32,
         dim_proprio = 21,
         num_tasks = 4,
+        num_video_views = num_video_views,
         num_discrete_actions = 4,
         num_residual_streams = 1
     )
 
-    video = torch.randn(2, 3, 10, 256, 256)
+    if num_video_views > 1:
+        video = torch.randn(2, num_video_views, 3, 10, 256, 256)
+    else:
+        video = torch.randn(2, 3, 10, 256, 256)
 
     rewards = torch.randn(2, 10)
     proprio = torch.randn(2, 10, 21)
     discrete_actions = torch.randint(0, 4, (2, 10, 1))
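
With the parametrization above, the same test now exercises both the single-view path (plain (b c t vh vw) video) and the two-view path. Assuming param is the file's usual alias for pytest.mark.parametrize, something like pytest -k test_proprioception runs both cases.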