first pass through the kv cache for the time block in the dynamics model
This commit is contained in:
parent
a7e0c395c3
commit
ca244a290c
@ -861,7 +861,19 @@ def apply_rotations(
|
||||
rotations, # (h n d) | (n d)
|
||||
t # (b h n d)
|
||||
):
|
||||
heads, dtype = t.shape[1], t.dtype
|
||||
|
||||
heads, seq_len, dtype = *t.shape[1:3], t.dtype
|
||||
|
||||
rotations_seq_len = rotations.shape[-2]
|
||||
|
||||
# handle kv caching with rotations
|
||||
# todo - only fetch rotary embedding for one timestep
|
||||
|
||||
if rotations_seq_len > seq_len:
|
||||
rotations = rotations[-seq_len:]
|
||||
|
||||
# precision
|
||||
|
||||
t = t.float()
|
||||
|
||||
# handle gqa for rotary
|
||||
@ -1316,7 +1328,13 @@ class AxialSpaceTimeTransformer(Module):
|
||||
|
||||
time_attn_kv_caches = []
|
||||
|
||||
has_kv_cache = exists(kv_cache)
|
||||
|
||||
if has_kv_cache:
|
||||
past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
|
||||
|
||||
kv_cache = default(kv_cache, (None,))
|
||||
|
||||
iter_kv_cache = iter(kv_cache)
|
||||
|
||||
# attention
|
||||
@ -1362,6 +1380,10 @@ class AxialSpaceTimeTransformer(Module):
|
||||
|
||||
out = self.final_norm(tokens)
|
||||
|
||||
if has_kv_cache:
|
||||
# just concat the past tokens back on for now, todo - clean up the logic
|
||||
out = cat((past_tokens, out), dim = 1)
|
||||
|
||||
if not return_kv_cache:
|
||||
return out
|
||||
|
||||
@ -2020,9 +2042,11 @@ class DynamicsWorldModel(Module):
|
||||
image_width = None,
|
||||
return_decoded_video = None,
|
||||
context_signal_noise = 0.1, # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc
|
||||
use_time_kv_cache = True,
|
||||
return_rewards_per_frame = False,
|
||||
return_agent_actions = False,
|
||||
return_log_probs_and_values = False
|
||||
return_log_probs_and_values = False,
|
||||
return_time_kv_cache = False
|
||||
|
||||
): # (b t n d) | (b c t h w)
|
||||
|
||||
@ -2072,6 +2096,10 @@ class DynamicsWorldModel(Module):
|
||||
if return_rewards_per_frame:
|
||||
decoded_rewards = empty((batch_size, 0), device = self.device, dtype = torch.float32)
|
||||
|
||||
# handle maybe time kv cache
|
||||
|
||||
time_kv_cache = None
|
||||
|
||||
# while all the frames of the video (per latent) is not generated
|
||||
|
||||
while latents.shape[1] < time_steps:
|
||||
@ -2080,6 +2108,8 @@ class DynamicsWorldModel(Module):
|
||||
noised_latent = randn((batch_size, 1, *latent_shape), device = self.device)
|
||||
|
||||
for step in range(num_steps):
|
||||
is_last_step = (step + 1) == num_steps
|
||||
|
||||
signal_levels = full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
|
||||
|
||||
noised_context = latents.lerp(past_context_noise, context_signal_noise) # the paragraph after eq (8)
|
||||
@ -2088,7 +2118,7 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1)
|
||||
|
||||
pred, agent_embed = self.forward(
|
||||
pred, (agent_embed, next_time_kv_cache) = self.forward(
|
||||
latents = noised_latent_with_context,
|
||||
signal_levels = signal_levels_with_context,
|
||||
step_sizes = step_size,
|
||||
@ -2096,11 +2126,17 @@ class DynamicsWorldModel(Module):
|
||||
tasks = tasks,
|
||||
discrete_actions = decoded_discrete_actions,
|
||||
continuous_actions = decoded_continuous_actions,
|
||||
time_kv_cache = time_kv_cache,
|
||||
latent_is_noised = True,
|
||||
return_pred_only = True,
|
||||
return_agent_tokens = True
|
||||
return_intermediates = True,
|
||||
)
|
||||
|
||||
if use_time_kv_cache and is_last_step:
|
||||
time_kv_cache = next_time_kv_cache
|
||||
|
||||
# unpack pred
|
||||
|
||||
_, pred = unpack(pred, pack_context_shape, 'b * n d')
|
||||
|
||||
# derive flow, based on whether in x-space or not
|
||||
@ -2186,7 +2222,12 @@ class DynamicsWorldModel(Module):
|
||||
# only return video or latent if not requesting anything else, for first stage training
|
||||
|
||||
if not has_at_least_one(return_rewards_per_frame, return_agent_actions):
|
||||
return video if return_decoded_video else latents
|
||||
out = video if return_decoded_video else latents
|
||||
|
||||
if not return_time_kv_cache:
|
||||
return out
|
||||
|
||||
return out, time_kv_cache
|
||||
|
||||
# returning agent actions, rewards, and log probs + values for policy optimization
|
||||
|
||||
@ -2209,7 +2250,10 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
gen.values = decoded_values
|
||||
|
||||
return gen
|
||||
if not return_time_kv_cache:
|
||||
return gen
|
||||
|
||||
return gen, time_kv_cache
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -2226,10 +2270,11 @@ class DynamicsWorldModel(Module):
|
||||
continuous_actions = None, # (b t na) | (b t-1 na)
|
||||
discrete_action_types = None, # (na)
|
||||
continuous_action_types = None, # (na)
|
||||
time_kv_cache = None,
|
||||
return_pred_only = False,
|
||||
latent_is_noised = False,
|
||||
return_all_losses = False,
|
||||
return_agent_tokens = False,
|
||||
return_intermediates = False,
|
||||
add_autoregressive_action_loss = False,
|
||||
update_loss_ema = None
|
||||
):
|
||||
@ -2397,6 +2442,9 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
action_tokens = add('1 d, b t d', self.action_learned_embed, action_tokens)
|
||||
|
||||
elif self.action_embedder.has_actions:
|
||||
action_tokens = torch.zeros_like(agent_tokens[:, :, 0:1])
|
||||
|
||||
else:
|
||||
action_tokens = agent_tokens[:, :, 0:0] # else empty off agent tokens
|
||||
|
||||
@ -2440,7 +2488,7 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
# attention
|
||||
|
||||
tokens, time_kv_cache = self.transformer(tokens, return_kv_cache = True)
|
||||
tokens, next_time_kv_cache = self.transformer(tokens, kv_cache = time_kv_cache, return_kv_cache = True)
|
||||
|
||||
# unpack
|
||||
|
||||
@ -2455,7 +2503,10 @@ class DynamicsWorldModel(Module):
|
||||
if not return_agent_tokens:
|
||||
return pred
|
||||
|
||||
return pred, agent_tokens
|
||||
if not return_time_kv_cache:
|
||||
return pred, agent_tokens
|
||||
|
||||
return pred, (agent_tokens, next_time_kv_cache)
|
||||
|
||||
# curry into get_prediction what does not change during first call as well as the shortcut ones
|
||||
|
||||
@ -2463,13 +2514,13 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
# forward the network
|
||||
|
||||
pred, encoded_agent_tokens = _get_prediction(noised_latents, signal_levels, step_sizes_log2, return_agent_tokens = True)
|
||||
pred, (encoded_agent_tokens, next_time_kv_cache) = _get_prediction(noised_latents, signal_levels, step_sizes_log2, return_agent_tokens = True, return_time_kv_cache = True)
|
||||
|
||||
if return_pred_only:
|
||||
if not return_agent_tokens:
|
||||
if not return_intermediates:
|
||||
return pred
|
||||
|
||||
return pred, encoded_agent_tokens
|
||||
return pred, (encoded_agent_tokens, next_time_kv_cache)
|
||||
|
||||
# determine the target for the loss
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dreamer4"
|
||||
version = "0.0.48"
|
||||
version = "0.0.49"
|
||||
description = "Dreamer 4"
|
||||
authors = [
|
||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||
|
||||
@ -12,6 +12,7 @@ import torch
|
||||
@param('condition_on_actions', (False, True))
|
||||
@param('num_residual_streams', (1, 4))
|
||||
@param('add_reward_embed_to_agent_token', (False, True))
|
||||
@param('use_time_kv_cache', (False, True))
|
||||
def test_e2e(
|
||||
pred_orig_latent,
|
||||
grouped_query_attn,
|
||||
@ -22,7 +23,8 @@ def test_e2e(
|
||||
signal_and_step_passed_in,
|
||||
condition_on_actions,
|
||||
num_residual_streams,
|
||||
add_reward_embed_to_agent_token
|
||||
add_reward_embed_to_agent_token,
|
||||
use_time_kv_cache
|
||||
):
|
||||
from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
|
||||
|
||||
@ -108,7 +110,8 @@ def test_e2e(
|
||||
image_height = 128,
|
||||
image_width = 128,
|
||||
batch_size = 2,
|
||||
return_rewards_per_frame = True
|
||||
return_rewards_per_frame = True,
|
||||
use_time_kv_cache = use_time_kv_cache
|
||||
)
|
||||
|
||||
assert generations.video.shape == (2, 3, 10, 128, 128)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user