first pass through the kv cache for the time block in the dynamics model

lucidrains 2025-10-20 12:25:50 -07:00
parent a7e0c395c3
commit ca244a290c
3 changed files with 69 additions and 15 deletions

dreamer4/dreamer4.py

@@ -861,7 +861,19 @@ def apply_rotations(
     rotations, # (h n d) | (n d)
     t # (b h n d)
 ):
-    heads, dtype = t.shape[1], t.dtype
+    heads, seq_len, dtype = *t.shape[1:3], t.dtype
+
+    rotations_seq_len = rotations.shape[-2]
+
+    # handle kv caching with rotations
+    # todo - only fetch rotary embedding for one timestep
+
+    if rotations_seq_len > seq_len:
+        rotations = rotations[-seq_len:]
+
     # precision

     t = t.float()

     # handle gqa for rotary
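Note on the hunk above: once a kv cache is in play, `t` carries only the newest timestep while the rotary table is still built for the full sequence, so the table must be sliced to the trailing positions. A minimal standalone sketch of the same guard (illustrative names, not the repo's code):

import torch

def slice_rotations_for_cache(rotations, t):
    # rotations: (n d) table built for the full sequence
    # t: (b h n d), where n shrinks to 1 during cached decoding
    seq_len = t.shape[-2]
    if rotations.shape[-2] > seq_len:
        rotations = rotations[-seq_len:] # keep only the newest positions
    return rotations

rot = torch.randn(16, 32)    # rotary table for 16 positions
q = torch.randn(1, 8, 1, 32) # a single new timestep
assert slice_rotations_for_cache(rot, q).shape == (1, 32)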
@@ -1316,7 +1328,13 @@ class AxialSpaceTimeTransformer(Module):
         time_attn_kv_caches = []

         has_kv_cache = exists(kv_cache)

+        if has_kv_cache:
+            past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
+
         kv_cache = default(kv_cache, (None,))
         iter_kv_cache = iter(kv_cache)

         # attention
@@ -1362,6 +1380,10 @@ class AxialSpaceTimeTransformer(Module):
         out = self.final_norm(tokens)

+        if has_kv_cache:
+            # just concat the past tokens back on for now, todo - clean up the logic
+            out = cat((past_tokens, out), dim = 1)
+
         if not return_kv_cache:
             return out
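The two transformer hunks above form one pattern: when a cache exists, only the last timestep is pushed through attention, and the untouched past tokens are concatenated back afterwards so the output shape matches the uncached path. A runnable toy of that control flow, with `toy_attn` as a hypothetical stand-in for the cached attention blocks:

import torch

def toy_attn(x, cache = None):
    # stand-in for a cached attention block: new keys/values are appended
    cache = x if cache is None else torch.cat((cache, x), dim = 1)
    return x, cache

def cached_step(tokens, kv_cache = None):
    past_tokens = None

    if kv_cache is not None:
        # only the newest timestep needs computing
        past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]

    tokens, next_cache = toy_attn(tokens, kv_cache)

    if past_tokens is not None:
        # concat the past back on, mirroring the commit's todo
        tokens = torch.cat((past_tokens, tokens), dim = 1)

    return tokens, next_cache

x = torch.randn(1, 5, 8)
out, cache = cached_step(x)                   # first call, no cache
out, cache = cached_step(x, kv_cache = cache) # later calls reuse the cache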
@@ -2020,9 +2042,11 @@ class DynamicsWorldModel(Module):
         image_width = None,
         return_decoded_video = None,
         context_signal_noise = 0.1, # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc
+        use_time_kv_cache = True,
         return_rewards_per_frame = False,
         return_agent_actions = False,
-        return_log_probs_and_values = False
+        return_log_probs_and_values = False,
+        return_time_kv_cache = False
     ): # (b t n d) | (b c t h w)
@@ -2072,6 +2096,10 @@ class DynamicsWorldModel(Module):
         if return_rewards_per_frame:
             decoded_rewards = empty((batch_size, 0), device = self.device, dtype = torch.float32)

+        # handle maybe time kv cache
+
+        time_kv_cache = None
+
         # while all the frames of the video (per latent) is not generated

         while latents.shape[1] < time_steps:
@@ -2080,6 +2108,8 @@ class DynamicsWorldModel(Module):
             noised_latent = randn((batch_size, 1, *latent_shape), device = self.device)

             for step in range(num_steps):
+                is_last_step = (step + 1) == num_steps
+
                 signal_levels = full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)

                 noised_context = latents.lerp(past_context_noise, context_signal_noise) # the paragraph after eq (8)
@@ -2088,7 +2118,7 @@ class DynamicsWorldModel(Module):
                 signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1)

-                pred, agent_embed = self.forward(
+                pred, (agent_embed, next_time_kv_cache) = self.forward(
                     latents = noised_latent_with_context,
                     signal_levels = signal_levels_with_context,
                     step_sizes = step_size,
@@ -2096,11 +2126,17 @@ class DynamicsWorldModel(Module):
                     tasks = tasks,
                     discrete_actions = decoded_discrete_actions,
                     continuous_actions = decoded_continuous_actions,
+                    time_kv_cache = time_kv_cache,
                     latent_is_noised = True,
                     return_pred_only = True,
-                    return_agent_tokens = True
+                    return_intermediates = True,
                 )

+                if use_time_kv_cache and is_last_step:
+                    time_kv_cache = next_time_kv_cache
+
                 # unpack pred

                 _, pred = unpack(pred, pack_context_shape, 'b * n d')

                 # derive flow, based on whether in x-space or not
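One reading of the `is_last_step` gate above: each frame is denoised over `num_steps` passes, and only the final, cleanest pass is allowed to refresh the time kv cache, so intermediate noisy passes never pollute it. A runnable toy of that control flow (the model here is a dummy, not the repo's):

import torch

def dummy_model(x, time_kv_cache = None):
    # pretend dynamics step: returns a denoised guess and an updated cache
    next_cache = x if time_kv_cache is None else torch.cat((time_kv_cache, x), dim = 1)
    return x * 0.9, next_cache

time_steps, num_steps, use_time_kv_cache = 4, 3, True
time_kv_cache = None

for frame in range(time_steps):
    x = torch.randn(1, 1, 8)
    for step in range(num_steps):
        is_last_step = (step + 1) == num_steps
        x, next_cache = dummy_model(x, time_kv_cache)

        if use_time_kv_cache and is_last_step:
            # only the last denoising step of each frame commits to the cache
            time_kv_cache = next_cache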
@@ -2186,7 +2222,12 @@ class DynamicsWorldModel(Module):
         # only return video or latent if not requesting anything else, for first stage training

         if not has_at_least_one(return_rewards_per_frame, return_agent_actions):
-            return video if return_decoded_video else latents
+            out = video if return_decoded_video else latents
+
+            if not return_time_kv_cache:
+                return out
+
+            return out, time_kv_cache

         # returning agent actions, rewards, and log probs + values for policy optimization
@@ -2209,7 +2250,10 @@ class DynamicsWorldModel(Module):
             gen.values = decoded_values

-        return gen
+        if not return_time_kv_cache:
+            return gen
+
+        return gen, time_kv_cache

     def forward(
         self,
@@ -2226,10 +2270,11 @@ class DynamicsWorldModel(Module):
         continuous_actions = None, # (b t na) | (b t-1 na)
         discrete_action_types = None, # (na)
         continuous_action_types = None, # (na)
+        time_kv_cache = None,
         return_pred_only = False,
         latent_is_noised = False,
         return_all_losses = False,
-        return_agent_tokens = False,
+        return_intermediates = False,
         add_autoregressive_action_loss = False,
         update_loss_ema = None
     ):
@@ -2397,6 +2442,9 @@ class DynamicsWorldModel(Module):
             action_tokens = add('1 d, b t d', self.action_learned_embed, action_tokens)

         elif self.action_embedder.has_actions:
             action_tokens = torch.zeros_like(agent_tokens[:, :, 0:1])
+        else:
+            action_tokens = agent_tokens[:, :, 0:0] # else empty off agent tokens
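Aside on the `0:0` slice in the new `else` branch: it yields a zero-width tensor along that dimension (all other dims preserved), so downstream concatenation degrades gracefully to a no-op. A quick check:

import torch

t = torch.randn(2, 3, 4, 8)
empty = t[:, :, 0:0] # zero-width along dim 2
assert empty.shape == (2, 3, 0, 8)
assert torch.cat((t, empty), dim = 2).shape == t.shape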
@@ -2440,7 +2488,7 @@ class DynamicsWorldModel(Module):
         # attention

-        tokens, time_kv_cache = self.transformer(tokens, return_kv_cache = True)
+        tokens, next_time_kv_cache = self.transformer(tokens, kv_cache = time_kv_cache, return_kv_cache = True)

         # unpack
@@ -2455,7 +2503,10 @@ class DynamicsWorldModel(Module):
         if not return_agent_tokens:
             return pred

-        return pred, agent_tokens
+        if not return_time_kv_cache:
+            return pred, agent_tokens
+
+        return pred, (agent_tokens, next_time_kv_cache)

     # curry into get_prediction what does not change during first call as well as the shortcut ones
@@ -2463,13 +2514,13 @@ class DynamicsWorldModel(Module):
         # forward the network

-        pred, encoded_agent_tokens = _get_prediction(noised_latents, signal_levels, step_sizes_log2, return_agent_tokens = True)
+        pred, (encoded_agent_tokens, next_time_kv_cache) = _get_prediction(noised_latents, signal_levels, step_sizes_log2, return_agent_tokens = True, return_time_kv_cache = True)

         if return_pred_only:
-            if not return_agent_tokens:
+            if not return_intermediates:
                 return pred

-            return pred, encoded_agent_tokens
+            return pred, (encoded_agent_tokens, next_time_kv_cache)

         # determine the target for the loss
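The "curry" comment above presumably means binding the per-call-invariant kwargs once, e.g. via `functools.partial` (a sketch under that assumption, with placeholder return values rather than the repo's):

from functools import partial

def get_prediction(latents, signal_levels, step_sizes, return_agent_tokens = False, return_time_kv_cache = False):
    pred, agent_tokens, kv_cache = 'pred', 'agent', 'cache' # placeholders

    if not return_agent_tokens:
        return pred
    if not return_time_kv_cache:
        return pred, agent_tokens
    return pred, (agent_tokens, kv_cache)

# bind once what does not change between denoising steps
_get_prediction = partial(get_prediction, return_agent_tokens = True, return_time_kv_cache = True)

pred, (agent_tokens, kv_cache) = _get_prediction('latents', 'levels', 'steps')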

pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.48"
+version = "0.0.49"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

tests/test_dreamer4.py

@@ -12,6 +12,7 @@ import torch
 @param('condition_on_actions', (False, True))
 @param('num_residual_streams', (1, 4))
 @param('add_reward_embed_to_agent_token', (False, True))
+@param('use_time_kv_cache', (False, True))
 def test_e2e(
     pred_orig_latent,
     grouped_query_attn,
@@ -22,7 +23,8 @@ def test_e2e(
     signal_and_step_passed_in,
     condition_on_actions,
     num_residual_streams,
-    add_reward_embed_to_agent_token
+    add_reward_embed_to_agent_token,
+    use_time_kv_cache
 ):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
@@ -108,7 +110,8 @@ def test_e2e(
         image_height = 128,
         image_width = 128,
         batch_size = 2,
-        return_rewards_per_frame = True
+        return_rewards_per_frame = True,
+        use_time_kv_cache = use_time_kv_cache
     )

     assert generations.video.shape == (2, 3, 10, 128, 128)