first pass through the kv cache for the time block in the dynamics model
commit ca244a290c
parent a7e0c395c3
@@ -861,7 +861,19 @@ def apply_rotations(
     rotations, # (h n d) | (n d)
     t          # (b h n d)
 ):
-    heads, dtype = t.shape[1], t.dtype
+    heads, seq_len, dtype = *t.shape[1:3], t.dtype
+
+    rotations_seq_len = rotations.shape[-2]
+
+    # handle kv caching with rotations
+    # todo - only fetch rotary embedding for one timestep
+
+    if rotations_seq_len > seq_len:
+        rotations = rotations[-seq_len:]
+
+    # precision

     t = t.float()

     # handle gqa for rotary
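A minimal runnable sketch (not from the commit) of the trimming added to apply_rotations above: with a kv cache only the newest timestep flows through attention, so the rotary embeddings precomputed for the full sequence have to be sliced down to the trailing positions. Shapes and values below are illustrative only.

import torch

full_seq_len, dim = 8, 16
rotations = torch.randn(full_seq_len, dim)  # rotary embeddings precomputed for all 8 positions

t = torch.randn(1, 4, 1, dim)               # (b h n d) - only the single newest timestep when kv caching

seq_len = t.shape[2]
rotations_seq_len = rotations.shape[-2]

# same trimming as the hunk above: keep only the rotations for the trailing positions
if rotations_seq_len > seq_len:
    rotations = rotations[-seq_len:]

assert rotations.shape == (seq_len, dim)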
@@ -1316,7 +1328,13 @@ class AxialSpaceTimeTransformer(Module):
         time_attn_kv_caches = []

+        has_kv_cache = exists(kv_cache)
+
+        if has_kv_cache:
+            past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
+
         kv_cache = default(kv_cache, (None,))

         iter_kv_cache = iter(kv_cache)

         # attention
@@ -1362,6 +1380,10 @@ class AxialSpaceTimeTransformer(Module):
         out = self.final_norm(tokens)

+        if has_kv_cache:
+            # just concat the past tokens back on for now, todo - clean up the logic
+            out = cat((past_tokens, out), dim = 1)
+
         if not return_kv_cache:
             return out

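For orientation, a toy version of the decode pattern these two AxialSpaceTimeTransformer hunks set up, assuming names and shapes not taken from the repo: when a cache is passed in, only the newest timestep is run through time attention and the untouched past tokens are concatenated back onto the output. The "cache" below is just the running key/value token sequence; the real transformer keeps per-layer key/value tensors and also interleaves spatial attention.

import torch
from torch import nn, cat

class ToyTimeTransformer(nn.Module):
    def __init__(self, dim = 16):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads = 4, batch_first = True)
        self.final_norm = nn.LayerNorm(dim)

    def forward(self, tokens, kv_cache = None, return_kv_cache = False):
        has_kv_cache = kv_cache is not None

        if has_kv_cache:
            # only the newest timestep attends; everything before it is served from the cache
            past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
            context = cat((kv_cache, tokens), dim = 1)
        else:
            context = tokens

        out, _ = self.attn(tokens, context, context)
        out = self.final_norm(out)

        if has_kv_cache:
            # concat the untouched past tokens back on, as in the hunk above
            out = cat((past_tokens, out), dim = 1)

        if not return_kv_cache:
            return out

        return out, context

# usage: the first call primes the cache, later calls only process the newest timestep

model = ToyTimeTransformer()

out, cache = model(torch.randn(2, 5, 16), return_kv_cache = True)
next_out = model(torch.randn(2, 6, 16), kv_cache = cache)

assert next_out.shape == (2, 6, 16)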
@@ -2020,9 +2042,11 @@ class DynamicsWorldModel(Module):
         image_width = None,
         return_decoded_video = None,
         context_signal_noise = 0.1, # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc
+        use_time_kv_cache = True,
         return_rewards_per_frame = False,
         return_agent_actions = False,
-        return_log_probs_and_values = False
+        return_log_probs_and_values = False,
+        return_time_kv_cache = False
     ): # (b t n d) | (b c t h w)
@@ -2072,6 +2096,10 @@ class DynamicsWorldModel(Module):
         if return_rewards_per_frame:
             decoded_rewards = empty((batch_size, 0), device = self.device, dtype = torch.float32)

+        # handle maybe time kv cache
+
+        time_kv_cache = None
+
         # while all the frames of the video (per latent) is not generated

         while latents.shape[1] < time_steps:
@@ -2080,6 +2108,8 @@ class DynamicsWorldModel(Module):
             noised_latent = randn((batch_size, 1, *latent_shape), device = self.device)

             for step in range(num_steps):

+                is_last_step = (step + 1) == num_steps
+
                 signal_levels = full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)

                 noised_context = latents.lerp(past_context_noise, context_signal_noise) # the paragraph after eq (8)
@@ -2088,7 +2118,7 @@ class DynamicsWorldModel(Module):
                 signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1)

-                pred, agent_embed = self.forward(
+                pred, (agent_embed, next_time_kv_cache) = self.forward(
                     latents = noised_latent_with_context,
                     signal_levels = signal_levels_with_context,
                     step_sizes = step_size,
@@ -2096,11 +2126,17 @@ class DynamicsWorldModel(Module):
                     tasks = tasks,
                     discrete_actions = decoded_discrete_actions,
                     continuous_actions = decoded_continuous_actions,
+                    time_kv_cache = time_kv_cache,
                     latent_is_noised = True,
                     return_pred_only = True,
-                    return_agent_tokens = True
+                    return_intermediates = True,
                 )

+                if use_time_kv_cache and is_last_step:
+                    time_kv_cache = next_time_kv_cache
+
+                # unpack pred
+
                 _, pred = unpack(pred, pack_context_shape, 'b * n d')

                 # derive flow, based on whether in x-space or not
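A stubbed-out sketch (not the commit's code) of why the sampling loop only adopts the cache on the last denoising step: intermediate steps keep re-noising the current frame, so presumably their caches would describe a latent that is about to change, and only the final step of a frame yields keys/values worth carrying into the next frame. model_step below is a hypothetical stand-in for self.forward.

def model_step(latent, time_kv_cache = None):
    # hypothetical stand-in returning (pred, next_time_kv_cache)
    return latent, (time_kv_cache or []) + [latent]

num_frames, num_steps, use_time_kv_cache = 3, 4, True

time_kv_cache = None

for frame in range(num_frames):
    latent = float(frame)

    for step in range(num_steps):
        is_last_step = (step + 1) == num_steps

        pred, next_time_kv_cache = model_step(latent, time_kv_cache = time_kv_cache)

        # caches from intermediate steps describe a still-denoising latent,
        # so only the cache produced on the final step of each frame is carried forward
        if use_time_kv_cache and is_last_step:
            time_kv_cache = next_time_kv_cache

assert len(time_kv_cache) == num_frames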
@@ -2186,7 +2222,12 @@ class DynamicsWorldModel(Module):
         # only return video or latent if not requesting anything else, for first stage training

         if not has_at_least_one(return_rewards_per_frame, return_agent_actions):
-            return video if return_decoded_video else latents
+            out = video if return_decoded_video else latents
+
+            if not return_time_kv_cache:
+                return out
+
+            return out, time_kv_cache

         # returning agent actions, rewards, and log probs + values for policy optimization

@@ -2209,7 +2250,10 @@ class DynamicsWorldModel(Module):
         gen.values = decoded_values

-        return gen
+        if not return_time_kv_cache:
+            return gen
+
+        return gen, time_kv_cache

     def forward(
         self,
@@ -2226,10 +2270,11 @@ class DynamicsWorldModel(Module):
         continuous_actions = None, # (b t na) | (b t-1 na)
         discrete_action_types = None, # (na)
         continuous_action_types = None, # (na)
+        time_kv_cache = None,
         return_pred_only = False,
         latent_is_noised = False,
         return_all_losses = False,
-        return_agent_tokens = False,
+        return_intermediates = False,
         add_autoregressive_action_loss = False,
         update_loss_ema = None
     ):
@@ -2397,6 +2442,9 @@ class DynamicsWorldModel(Module):
             action_tokens = add('1 d, b t d', self.action_learned_embed, action_tokens)

+        elif self.action_embedder.has_actions:
+            action_tokens = torch.zeros_like(agent_tokens[:, :, 0:1])
+
         else:
             action_tokens = agent_tokens[:, :, 0:0] # else empty off agent tokens

@@ -2440,7 +2488,7 @@ class DynamicsWorldModel(Module):
            # attention

-            tokens, time_kv_cache = self.transformer(tokens, return_kv_cache = True)
+            tokens, next_time_kv_cache = self.transformer(tokens, kv_cache = time_kv_cache, return_kv_cache = True)

            # unpack

@@ -2455,7 +2503,10 @@ class DynamicsWorldModel(Module):
            if not return_agent_tokens:
                return pred

-            return pred, agent_tokens
+            if not return_time_kv_cache:
+                return pred, agent_tokens
+
+            return pred, (agent_tokens, next_time_kv_cache)

         # curry into get_prediction what does not change during first call as well as the shortcut ones

@@ -2463,13 +2514,13 @@ class DynamicsWorldModel(Module):
         # forward the network

-        pred, encoded_agent_tokens = _get_prediction(noised_latents, signal_levels, step_sizes_log2, return_agent_tokens = True)
+        pred, (encoded_agent_tokens, next_time_kv_cache) = _get_prediction(noised_latents, signal_levels, step_sizes_log2, return_agent_tokens = True, return_time_kv_cache = True)

         if return_pred_only:
-            if not return_agent_tokens:
+            if not return_intermediates:
                 return pred

-            return pred, encoded_agent_tokens
+            return pred, (encoded_agent_tokens, next_time_kv_cache)

         # determine the target for the loss

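A toy illustration (placeholder values, not repository code) of the nested return convention these hunks settle on, so the unpacking in the sampling loop, pred, (agent_embed, next_time_kv_cache) = self.forward(...), is easy to follow: when both the agent tokens and the time kv cache are requested, the cache rides along with the agent tokens in one nested tuple.

def get_prediction(return_agent_tokens = False, return_time_kv_cache = False):
    # placeholder values standing in for the real tensors
    pred, agent_tokens, next_time_kv_cache = 'pred', 'agent', 'cache'

    if not return_agent_tokens:
        return pred

    if not return_time_kv_cache:
        return pred, agent_tokens

    # both requested - cache is nested alongside the agent tokens
    return pred, (agent_tokens, next_time_kv_cache)

pred, (agent, cache) = get_prediction(return_agent_tokens = True, return_time_kv_cache = True)
assert (pred, agent, cache) == ('pred', 'agent', 'cache')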
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.48"
+version = "0.0.49"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -12,6 +12,7 @@ import torch
 @param('condition_on_actions', (False, True))
 @param('num_residual_streams', (1, 4))
 @param('add_reward_embed_to_agent_token', (False, True))
+@param('use_time_kv_cache', (False, True))
 def test_e2e(
     pred_orig_latent,
     grouped_query_attn,
@@ -22,7 +23,8 @@ def test_e2e(
     signal_and_step_passed_in,
     condition_on_actions,
     num_residual_streams,
-    add_reward_embed_to_agent_token
+    add_reward_embed_to_agent_token,
+    use_time_kv_cache
 ):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel

@@ -108,7 +110,8 @@ def test_e2e(
         image_height = 128,
         image_width = 128,
         batch_size = 2,
-        return_rewards_per_frame = True
+        return_rewards_per_frame = True,
+        use_time_kv_cache = use_time_kv_cache
     )

     assert generations.video.shape == (2, 3, 10, 128, 128)