@@ -1,4 +1,5 @@
from __future__ import annotations
from typing import Callable

import math
from math import ceil, log2
@@ -70,13 +71,19 @@ except ImportError:

LinearNoBias = partial(Linear, bias = False)

VideoTokenizerIntermediates = namedtuple('VideoTokenizerIntermediates', ('losses', 'recon'))

TokenizerLosses = namedtuple('TokenizerLosses', ('recon', 'lpips', 'time_decorr', 'space_decorr'))

WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions'))
WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_actions', 'continuous_actions', 'state_pred'))

AttentionIntermediates = namedtuple('AttentionIntermediates', ('next_kv_cache', 'normed_inputs'))

TransformerIntermediates = namedtuple('TransformerIntermediates', ('next_kv_cache', 'normed_time_inputs', 'normed_space_inputs'))
TransformerIntermediates = namedtuple('TransformerIntermediates', ('next_kv_cache', 'normed_time_inputs', 'normed_space_inputs', 'next_rnn_hiddens'))

Predictions = namedtuple('Predictions', ('flow', 'proprioception', 'state'))

Embeds = namedtuple('Embeds', ['agent', 'state_pred'])
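
# note: Predictions groups the three prediction heads (flow, proprioception, next-state), while
# Embeds carries the agent token embeddings alongside the state prediction token for downstream decoding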

MaybeTensor = Tensor | None

@@ -95,7 +102,7 @@ class Experience:
lens: MaybeTensor = None
is_truncated: MaybeTensor = None
agent_index: int = 0
is_from_world_model: bool = True
is_from_world_model: bool | Tensor = True

def cpu(self):
return self.to(torch.device('cpu'))
@@ -123,6 +130,9 @@ def combine_experiences(
if not exists(exp.is_truncated):
exp.is_truncated = full((batch,), True, device = device)

if isinstance(exp.is_from_world_model, bool):
exp.is_from_world_model = tensor(exp.is_from_world_model)

# convert to dictionary

exps_dict = [asdict(exp) for exp in exps]
@@ -144,11 +154,15 @@ def combine_experiences(

for field_values in all_field_values:

if is_tensor(first(field_values)):
first_value = first(field_values)

if is_tensor(first_value):

field_values = pad_tensors_at_dim_to_max_len(field_values, dims = (1, 2))

new_field_value = cat(field_values)
cat_or_stack = cat if first_value.ndim > 0 else stack

new_field_value = cat_or_stack(field_values)
else:
new_field_value = first(list(set(field_values)))
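
# note: tensor fields are padded and concatenated along the batch dimension, except 0-dim scalars
# (e.g. the tensor-ified is_from_world_model flag) which must be stacked instead; non-tensor fields
# are collapsed to a single representative value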

@@ -1453,6 +1467,29 @@ class SwiGLUFeedforward(Module):

return self.proj_out(x)

# rnn

class GRULayer(Module):
def __init__(
self,
dim,
dim_out
):
super().__init__()
self.norm = nn.RMSNorm(dim)
self.gru = nn.GRU(dim, dim_out, batch_first = True)

def forward(
self,
x,
prev_hiddens = None
):
x = self.norm(x)

x, hiddens = self.gru(x, prev_hiddens)

return x, hiddens
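
# usage sketch (illustrative, shapes assumed): the layer pre-norms its input, runs a batch-first
# GRU along the time axis, and returns the sequence output together with the final hidden state,
# which can be fed back in on a later call
#
#   gru = GRULayer(dim = 512, dim_out = 512)
#   x = torch.randn(2, 16, 512)          # (batch, time, dim)
#   out, hiddens = gru(x)                # out: (2, 16, 512), hiddens: (1, 2, 512)
#   out, hiddens = gru(x, hiddens)       # continue from the cached hidden state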

# axial space time transformer

class AxialSpaceTimeTransformer(Module):
@@ -1471,7 +1508,7 @@ class AxialSpaceTimeTransformer(Module):
special_attend_only_itself = False, # this is set to True for the video tokenizer decoder (latents can only attend to themselves while spatial modalities attend to the latents and everything)
final_norm = True,
value_residual = True, # https://arxiv.org/abs/2410.17897 - but with learned mixing from OSS
rnn_time = False
rnn_time = True
):
super().__init__()
assert depth >= time_block_every, f'depth must be at least {time_block_every}'
@@ -1531,10 +1568,7 @@ class AxialSpaceTimeTransformer(Module):
hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
]))

rnn_layers.append(ModuleList([
nn.RMSNorm(dim),
nn.GRU(dim, dim, batch_first = True)
]) if is_time_block and rnn_time else None)
rnn_layers.append(hyper_conn(branch = GRULayer(dim, dim)) if is_time_block and rnn_time else None)

self.layers = ModuleList(layers)
self.rnn_layers = ModuleList(rnn_layers)
@@ -1560,8 +1594,8 @@ class AxialSpaceTimeTransformer(Module):

def forward(
self,
tokens, # (b t s d)
kv_cache: Tensor | None = None, # (y 2 b h t d)
tokens, # (b t s d)
cache: TransformerIntermediates | None = None,
return_intermediates = False

): # (b t s d) | (y 2 b h t d)
@@ -1570,6 +1604,14 @@ class AxialSpaceTimeTransformer(Module):

assert tokens.ndim == 4

# destructure the cached intermediates into the attention kv cache and the rnn hidden states

kv_cache = rnn_prev_hiddens = None

if exists(cache):
kv_cache = cache.next_kv_cache
rnn_prev_hiddens = cache.next_rnn_hiddens

# attend functions for space and time

has_kv_cache = exists(kv_cache)
@@ -1584,6 +1626,7 @@ class AxialSpaceTimeTransformer(Module):
# prepare cache

time_attn_kv_caches = []
rnn_hiddens = []

if has_kv_cache:
past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
@@ -1598,6 +1641,10 @@ class AxialSpaceTimeTransformer(Module):

iter_kv_cache = iter(kv_cache)

rnn_prev_hiddens = default(rnn_prev_hiddens, (None,))

iter_rnn_prev_hiddens = iter(rnn_prev_hiddens)

# rotary

rotary_pos_emb = self.time_rotary(rotary_seq_len, offset = rotary_pos_offset)
@@ -1618,22 +1665,21 @@ class AxialSpaceTimeTransformer(Module):

tokens = self.expand_streams(tokens)

for (pre_attn_rearrange, post_attn_rearrange, attn, ff), maybe_rnn_modules, layer_is_time in zip(self.layers, self.rnn_layers, self.is_time):
for (pre_attn_rearrange, post_attn_rearrange, attn, ff), maybe_rnn, layer_is_time in zip(self.layers, self.rnn_layers, self.is_time):

tokens = pre_attn_rearrange(tokens)

# maybe rnn for time

if layer_is_time and exists(maybe_rnn_modules):
rnn_prenorm, rnn = maybe_rnn_modules
if layer_is_time and exists(maybe_rnn):

rnn_input, inverse_pack_time = pack_one(tokens, '* t d')
tokens, inverse_pack_batch = pack_one(tokens, '* t d')

rnn_out, rnn_hiddens = rnn(rnn_prenorm(rnn_input)) # todo, handle rnn cache
tokens, layer_rnn_hiddens = maybe_rnn(tokens, next(iter_rnn_prev_hiddens, None)) # todo, handle rnn cache

rnn_out = inverse_pack_time(rnn_out)
tokens = inverse_pack_batch(tokens)

tokens = rnn_out + tokens
rnn_hiddens.append(layer_rnn_hiddens)

# when it is an axial time attention block, it should be causal

@@ -1691,7 +1737,8 @@ class AxialSpaceTimeTransformer(Module):
intermediates = TransformerIntermediates(
stack(time_attn_kv_caches),
safe_stack(normed_time_attn_inputs),
safe_stack(normed_space_attn_inputs)
safe_stack(normed_space_attn_inputs),
safe_stack(rnn_hiddens)
)

return out, intermediates
@@ -1723,11 +1770,6 @@ class VideoTokenizer(Module):
encoder_add_decor_aux_loss = False,
decor_auxx_loss_weight = 0.1,
decorr_sample_frac = 0.25,
nd_rotary_kwargs: dict = dict(
rope_min_freq = 1.,
rope_max_freq = 10000.,
rope_p_zero_freqs = 0.
),
num_residual_streams = 1,
):
super().__init__()
@@ -1900,11 +1942,23 @@ class VideoTokenizer(Module):

def forward(
self,
video, # (b c t h w)
video_or_image, # (b c t h w) | (b c h w)
return_latents = False,
mask_patches = None,
return_all_losses = False
return_intermediates = False,
):

# handle image pretraining

is_image = video_or_image.ndim == 4

if is_image:
video = rearrange(video_or_image, 'b c h w -> b c 1 h w')
else:
video = video_or_image
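
# note: an image batch is treated as a single-frame video by inserting a time dimension of
# length 1, so the same tokenizer can be pretrained on images and later run on video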

# shapes

batch, _, time, height, width = video.shape
patch_size, device = self.patch_size, video.device

@@ -1944,7 +1998,7 @@ class VideoTokenizer(Module):

# encoder attention

tokens, (_, time_attn_normed_inputs, space_attn_normed_inputs) = self.encoder_transformer(tokens, return_intermediates = True)
tokens, (_, time_attn_normed_inputs, space_attn_normed_inputs, _) = self.encoder_transformer(tokens, return_intermediates = True)

# latent bottleneck

@@ -1984,12 +2038,21 @@ class VideoTokenizer(Module):
space_decorr_loss * self.decorr_aux_loss_weight
)

if not return_all_losses:
if not return_intermediates:
return total_loss

losses = (recon_loss, lpips_loss, decorr_loss)
losses = TokenizerLosses(recon_loss, lpips_loss, time_decorr_loss, space_decorr_loss)

return total_loss, TokenizerLosses(*losses)
out = losses

# handle returning of the reconstructed video, and image pretraining

if is_image:
recon_video = rearrange(recon_video, 'b c 1 h w -> b c h w')

out = (losses, recon_video)

return total_loss, out

# dynamics model, axial space-time transformer

@@ -2017,15 +2080,18 @@ class DynamicsWorldModel(Module):
attn_dim_head = 64,
attn_softclamp_value = 50.,
ff_kwargs: dict = dict(),
use_time_rnn = True,
loss_weight_fn: Callable = ramp_weight,
num_future_predictions = 8, # they do multi-token prediction of 8 steps forward
prob_no_shortcut_train = None, # probability of no shortcut training, defaults to 1 / num_step_sizes
add_reward_embed_to_agent_token = False,
add_reward_embed_dropout = 0.1,
add_state_pred_head = False,
state_pred_loss_weight = 0.1,
state_entropy_bonus_weight = 0.05,
num_discrete_actions: int | tuple[int, ...] = 0,
num_continuous_actions = 0,
continuous_norm_stats = None,
multi_token_pred_len = 8,
multi_token_pred_len = 8, # they do multi-token prediction of 8 steps forward
value_head_mlp_depth = 3,
policy_head_mlp_depth = 3,
latent_flow_loss_weight = 1.,
@@ -2145,6 +2211,25 @@ class DynamicsWorldModel(Module):
self.pred_orig_latent = pred_orig_latent # x-space or v-space
self.loss_weight_fn = loss_weight_fn

# state prediction, for state entropy bonus

self.add_state_pred_head = add_state_pred_head
self.state_pred_loss_weight = state_pred_loss_weight

self.should_pred_state = add_state_pred_head and state_pred_loss_weight > 0.

if self.should_pred_state:
self.state_pred_token = nn.Parameter(torch.randn(dim) * 1e-2)

self.to_state_pred = Sequential(
RMSNorm(dim),
nn.Linear(dim, num_latent_tokens * dim_latent * 2),
Rearrange('... (n d two) -> ... n d two', n = num_latent_tokens, two = 2)
)

self.state_entropy_bonus_weight = state_entropy_bonus_weight
self.add_state_entropy_bonus = self.should_pred_state and state_entropy_bonus_weight > 0.
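
# note: the state prediction head emits a mean and a log-variance for every latent token
# (the trailing dimension of 2), which later parameterize a Gaussian NLL loss and, during
# rollouts, a state-entropy exploration bonus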

# reinforcement related

# they sum all the actions into a single token
@@ -2235,6 +2320,7 @@ class DynamicsWorldModel(Module):
num_special_spatial_tokens = num_agents,
time_block_every = time_block_every,
final_norm = False,
rnn_time = use_time_rnn,
**transformer_kwargs
)

@@ -2378,10 +2464,10 @@ class DynamicsWorldModel(Module):
env,
seed = None,
agent_index = 0,
step_size = 4,
num_steps = 4,
max_timesteps = 16,
env_is_vectorized = False,
use_time_kv_cache = True,
use_time_cache = True,
store_agent_embed = True,
store_old_action_unembeds = True,
):
@@ -2418,9 +2504,14 @@ class DynamicsWorldModel(Module):

episode_lens = full((batch,), 0, device = device)

# derive step size

assert divisible_by(self.max_steps, num_steps)
step_size = self.max_steps // num_steps
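
# note: instead of passing a raw step size, the rollout now takes num_steps and derives the
# shortcut step size from it; e.g. if self.max_steps were 16, num_steps = 4 gives step_size = 4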

# maybe time kv cache

time_kv_cache = None
time_cache = None

step_index = 0

@@ -2429,14 +2520,14 @@ class DynamicsWorldModel(Module):

latents = self.video_tokenizer(video, return_latents = True)

_, (agent_embed, next_time_kv_cache) = self.forward(
_, (embeds, next_time_cache) = self.forward(
latents = latents,
signal_levels = self.max_steps - 1,
step_sizes = step_size,
rewards = rewards,
discrete_actions = discrete_actions,
continuous_actions = continuous_actions,
time_kv_cache = time_kv_cache,
time_cache = time_cache,
latent_is_noised = True,
return_pred_only = True,
return_intermediates = True
@@ -2444,11 +2535,13 @@ class DynamicsWorldModel(Module):

# time kv cache

if use_time_kv_cache:
time_kv_cache = next_time_kv_cache
if use_time_cache:
time_cache = next_time_cache

# get one agent

agent_embed = embeds.agent

one_agent_embed = agent_embed[..., -1:, agent_index, :]

# values
@@ -2503,6 +2596,17 @@ class DynamicsWorldModel(Module):
elif len(env_step_out) == 5:
next_frame, reward, terminated, truncated, info = env_step_out

# maybe add state entropy bonus

if self.add_state_entropy_bonus:
state_pred_token = embeds.state_pred

state_pred = self.to_state_pred(state_pred_token)

state_pred_log_variance = state_pred[..., 1].sum()

reward = reward + state_pred_log_variance * self.state_entropy_bonus_weight
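
# note: the summed predicted log-variance acts as an exploration bonus - states the model is
# less certain about add more to the reward, scaled by state_entropy_bonus_weight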

# update episode lens

episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)
@@ -2681,7 +2785,7 @@ class DynamicsWorldModel(Module):
):

with world_model_forward_context():
_, (agent_embeds, _) = self.forward(
_, (embeds, _) = self.forward(
latents = latents,
signal_levels = self.max_steps - 1,
step_sizes = step_size,
@@ -2693,7 +2797,7 @@ class DynamicsWorldModel(Module):
return_intermediates = True
)

agent_embeds = agent_embeds[..., agent_index, :]
agent_embeds = embeds.agent[..., agent_index, :]

# maybe detach agent embed

@@ -2837,13 +2941,13 @@ class DynamicsWorldModel(Module):
image_width = None,
return_decoded_video = None,
context_signal_noise = 0.1, # they apply noising to the past context; this was from an old diffusion world modeling paper from EPFL iirc
time_kv_cache: Tensor | None = None,
use_time_kv_cache = True,
time_cache: Tensor | None = None,
use_time_cache = True,
return_rewards_per_frame = False,
return_agent_actions = False,
return_log_probs_and_values = False,
return_for_policy_optimization = False,
return_time_kv_cache = False,
return_time_cache = False,
store_agent_embed = True,
store_old_action_unembeds = True

@@ -2932,7 +3036,7 @@ class DynamicsWorldModel(Module):
# (2) decoding anything off agent embedding (rewards, actions, etc)

take_extra_step = (
use_time_kv_cache or
use_time_cache or
return_rewards_per_frame or
store_agent_embed or
return_agent_actions
@@ -2973,7 +3077,7 @@ class DynamicsWorldModel(Module):

signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1)

pred, (agent_embed, next_time_kv_cache) = self.forward(
pred, (embeds, next_time_cache) = self.forward(
latents = noised_latent_with_context,
signal_levels = signal_levels_with_context,
step_sizes = step_size,
@@ -2983,15 +3087,15 @@ class DynamicsWorldModel(Module):
discrete_actions = decoded_discrete_actions,
continuous_actions = decoded_continuous_actions,
proprio = noised_proprio_with_context,
time_kv_cache = time_kv_cache,
time_cache = time_cache,
latent_is_noised = True,
latent_has_view_dim = True,
return_pred_only = True,
return_intermediates = True,
)

if use_time_kv_cache and is_last_step:
time_kv_cache = next_time_kv_cache
if use_time_cache and is_last_step:
time_cache = next_time_cache

# early break if taking an extra step for agent embedding off cleaned latents for decoding

@@ -3000,8 +3104,10 @@ class DynamicsWorldModel(Module):

# maybe proprio

if has_proprio:
pred, pred_proprio = pred
# maybe proprio

pred_proprio = pred.proprioception
pred = pred.flow

# unpack pred

@@ -3038,6 +3144,8 @@ class DynamicsWorldModel(Module):
# take care of the rewards by predicting on the agent token embedding on the last denoising step

if return_rewards_per_frame:
agent_embed = embeds.agent

one_agent_embed = agent_embed[:, -1:, agent_index]

reward_logits = self.to_reward_pred.forward_one(one_agent_embed, id = 0)
@@ -3048,6 +3156,8 @@ class DynamicsWorldModel(Module):
# maybe store agent embed

if store_agent_embed:
agent_embed = embeds.agent

one_agent_embed = agent_embed[:, -1:, agent_index]
acc_agent_embed = safe_cat((acc_agent_embed, one_agent_embed), dim = 1)

@@ -3140,10 +3250,10 @@ class DynamicsWorldModel(Module):
if not has_at_least_one(return_rewards_per_frame, return_agent_actions, has_proprio):
out = video if return_decoded_video else latents

if not return_time_kv_cache:
if not return_time_cache:
return out

return out, time_kv_cache
return out, time_cache

# returning agent actions, rewards, and log probs + values for policy optimization

@@ -3173,10 +3283,10 @@ class DynamicsWorldModel(Module):

gen.values = decoded_values

if not return_time_kv_cache:
if not return_time_cache:
return gen

return gen, time_kv_cache
return gen, time_cache

def forward(
self,
@@ -3195,7 +3305,7 @@ class DynamicsWorldModel(Module):
discrete_action_types = None, # (na)
continuous_action_types = None, # (na)
proprio = None, # (b t dp)
time_kv_cache = None,
time_cache = None,
return_pred_only = False,
latent_is_noised = False,
return_all_losses = False,
@@ -3345,9 +3455,13 @@ class DynamicsWorldModel(Module):

agent_tokens = repeat(agent_tokens, 'b ... d -> b t ... d', t = time)

# empty token

empty_token = agent_tokens[:, :, 0:0]

# maybe reward tokens

reward_tokens = agent_tokens[:, :, 0:0]
reward_tokens = empty_token

if exists(rewards):
two_hot_encoding = self.reward_encoder(rewards)
@@ -3387,6 +3501,13 @@ class DynamicsWorldModel(Module):
else:
noised_proprio = proprio

# maybe state prediction token

if self.should_pred_state:
state_pred_token = repeat(self.state_pred_token, 'd -> b t 1 d', b = batch, t = time)
else:
state_pred_token = empty_token

# maybe create the action tokens

if exists(discrete_actions) or exists(continuous_actions):
@@ -3411,11 +3532,11 @@ class DynamicsWorldModel(Module):
action_tokens = torch.zeros_like(agent_tokens[:, :, 0:1])

else:
action_tokens = agent_tokens[:, :, 0:0] # else empty off agent tokens
action_tokens = empty_token # else empty off agent tokens

# main function, needs to be defined as such for shortcut training - additional calls for consistency loss

def get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_kv_cache = False):
def get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, state_pred_token, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_cache = False):

# latents to spatial tokens

@@ -3464,15 +3585,15 @@ class DynamicsWorldModel(Module):

# pack to tokens for attending

tokens, packed_tokens_shape = pack([flow_token, space_tokens, proprio_token, registers, action_tokens, reward_tokens, agent_tokens], 'b t * d')
tokens, packed_tokens_shape = pack([flow_token, space_tokens, proprio_token, state_pred_token, registers, action_tokens, reward_tokens, agent_tokens], 'b t * d')

# attention

tokens, (next_time_kv_cache, *_) = self.transformer(tokens, kv_cache = time_kv_cache, return_intermediates = True)
tokens, intermediates = self.transformer(tokens, cache = time_cache, return_intermediates = True)

# unpack

flow_token, space_tokens, proprio_token, register_tokens, action_tokens, reward_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
flow_token, space_tokens, proprio_token, state_pred_token, register_tokens, action_tokens, reward_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')

# pooling

@@ -3484,42 +3605,54 @@ class DynamicsWorldModel(Module):

if self.has_proprio:
pred_proprio = self.to_proprio_pred(proprio_token)
else:
pred_proprio = None

pred = (pred, pred_proprio)
# maybe state pred

if self.should_pred_state:
pred_state = self.to_state_pred(state_pred_token)
else:
pred_state = None

# returning

predictions = Predictions(pred, pred_proprio, pred_state)

embeds = Embeds(agent_tokens, state_pred_token)

if not return_agent_tokens:
return pred
return predictions

if not return_time_kv_cache:
return pred, agent_tokens
if not return_time_cache:
return predictions, embeds

return pred, (agent_tokens, next_time_kv_cache)
return predictions, (embeds, intermediates)

# curry into get_prediction what does not change during the first call as well as the shortcut ones

_get_prediction = partial(get_prediction, action_tokens = action_tokens, reward_tokens = reward_tokens, agent_tokens = agent_tokens)
_get_prediction = partial(get_prediction, state_pred_token = state_pred_token, action_tokens = action_tokens, reward_tokens = reward_tokens, agent_tokens = agent_tokens)

# forward the network

pred, (encoded_agent_tokens, next_time_kv_cache) = _get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, return_agent_tokens = True, return_time_kv_cache = True)
pred, (embeds, intermediates) = _get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, return_agent_tokens = True, return_time_cache = True)

if return_pred_only:
if not return_intermediates:
return pred

return pred, (encoded_agent_tokens, next_time_kv_cache)
return pred, (embeds, intermediates)

# pack the predictions to calculate flow for different modalities all at once

if self.has_proprio:
pred, for_flow_loss_packed_shape = pack(pred, 'b t *')
packed_pred, for_flow_loss_packed_shape = pack((pred.flow, pred.proprioception), 'b t *')

noised, _ = pack((noised_latents, noised_proprio), 'b t *')
data, _ = pack((latents, proprio), 'b t *')
noise, _ = pack((noise, proprio_noise), 'b t *')
else:
packed_pred = pred.flow
noised = noised_latents
data = latents

@@ -3538,9 +3671,10 @@ class DynamicsWorldModel(Module):
pred = fn(noised, noised_proprio, *args, **kwargs)
if self.has_proprio:

pred, _ = pack(pred, 'b t *')
packed_flow, _ = pack((pred.flow, pred.proprioception), 'b t *')
return packed_flow

return pred
return pred.flow
return inner

wrapped_get_prediction = maybe_pack_unpack(_get_prediction)
@@ -3607,12 +3741,12 @@ class DynamicsWorldModel(Module):
# need to convert x-space to v-space

if is_x_space:
pred = (pred - noised) / (1. - first_times)
packed_pred = (packed_pred - noised) / (1. - first_times)
maybe_shortcut_loss_weight = (1. - first_times) ** 2

# mse loss

flow_losses = F.mse_loss(pred, pred_target, reduction = 'none')
flow_losses = F.mse_loss(packed_pred, pred_target, reduction = 'none')

flow_losses = flow_losses * maybe_shortcut_loss_weight # handle the (1-t)^2 in eq(7)

@@ -3644,6 +3778,8 @@ class DynamicsWorldModel(Module):

if exists(rewards):

encoded_agent_tokens = embeds.agent

if rewards.ndim == 2: # (b t)
encoded_agent_tokens = reduce(encoded_agent_tokens, 'b t g d -> b t d', 'mean')

@@ -3664,6 +3800,18 @@ class DynamicsWorldModel(Module):
else:
reward_loss = reduce(reward_losses, '... mtp -> mtp', 'mean') # they sum across the prediction steps (mtp dimension) - eq(9)

# maybe autoregressive state prediction loss

state_pred_loss = self.zero

if self.should_pred_state:
pred_latent, latent_to_pred = pred.state[:, :-1], latents[:, 1:]

pred_latent_mean, pred_latent_log_var = pred_latent.unbind(dim = -1)
pred_latent_var = pred_latent_log_var.exp()

state_pred_loss = F.gaussian_nll_loss(pred_latent_mean, latent_to_pred, var = pred_latent_var)
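
# note: the state head is scored one step ahead - the mean / log-variance predicted at frame t
# is compared against the actual latents at frame t + 1 with a Gaussian negative log likelihood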

# maybe autoregressive action loss

discrete_action_loss = self.zero
@@ -3731,7 +3879,7 @@ class DynamicsWorldModel(Module):

# handle loss normalization

losses = WorldModelLosses(flow_loss, reward_loss, discrete_action_loss, continuous_action_loss)
losses = WorldModelLosses(flow_loss, reward_loss, discrete_action_loss, continuous_action_loss, state_pred_loss)

if exists(self.flow_loss_normalizer):
flow_loss = self.flow_loss_normalizer(flow_loss, update_ema = update_loss_ema)
@@ -3751,7 +3899,8 @@ class DynamicsWorldModel(Module):
flow_loss * self.latent_flow_loss_weight +
(reward_loss * self.reward_loss_weight).sum() +
(discrete_action_loss * self.discrete_action_loss_weight).sum() +
(continuous_action_loss * self.continuous_action_loss_weight).sum()
(continuous_action_loss * self.continuous_action_loss_weight).sum() +
(state_pred_loss * self.state_pred_loss_weight)
)

if not return_all_losses: