fix the rnn time caching issue

lucidrains 2025-11-11 17:04:02 -08:00
parent ac1c12f743
commit 690ecf07dc
4 changed files with 84 additions and 48 deletions

View File

@ -1,6 +1,7 @@
from dreamer4.dreamer4 import (
VideoTokenizer,
DynamicsWorldModel
DynamicsWorldModel,
AxialSpaceTimeTransformer
)
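With AxialSpaceTimeTransformer added to the package exports, all three public classes can now be imported directly from the top-level package, for example:

from dreamer4 import VideoTokenizer, DynamicsWorldModel, AxialSpaceTimeTransformer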

View File

@ -1,4 +1,5 @@
from __future__ import annotations
from typing import Callable
import math
from math import ceil, log2
@ -76,7 +77,7 @@ WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_
AttentionIntermediates = namedtuple('AttentionIntermediates', ('next_kv_cache', 'normed_inputs'))
TransformerIntermediates = namedtuple('TransformerIntermediates', ('next_kv_cache', 'normed_time_inputs', 'normed_space_inputs'))
TransformerIntermediates = namedtuple('TransformerIntermediates', ('next_kv_cache', 'normed_time_inputs', 'normed_space_inputs', 'next_rnn_hiddens'))
MaybeTensor = Tensor | None
@ -1453,6 +1454,29 @@ class SwiGLUFeedforward(Module):
return self.proj_out(x)
# rnn
class GRULayer(Module):
def __init__(
self,
dim,
dim_out
):
super().__init__()
self.norm = nn.RMSNorm(dim)
self.gru = nn.GRU(dim, dim_out, batch_first = True)
def forward(
self,
x,
prev_hiddens = None
):
x = self.norm(x)
x, hiddens = self.gru(x, prev_hiddens)
return x, hiddens
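For reference, a standalone sketch of the pattern GRULayer wraps, in plain PyTorch with illustrative shapes not taken from this diff: the hidden state returned by one call is what the transformer now caches and feeds back on the next call.

import torch
from torch import nn

norm = nn.RMSNorm(64)
gru = nn.GRU(64, 64, batch_first = True)

x1 = torch.randn(2, 8, 64)               # (batch, time, dim), the first chunk of a sequence
out1, hiddens = gru(norm(x1))             # hiddens has shape (num_layers, batch, dim)

x2 = torch.randn(2, 1, 64)                # a single new timestep during incremental decoding
out2, hiddens = gru(norm(x2), hiddens)    # passing the cached hiddens continues the same sequence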
# axial space time transformer
class AxialSpaceTimeTransformer(Module):
@ -1531,7 +1555,7 @@ class AxialSpaceTimeTransformer(Module):
hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
]))
rnn_layers.append(hyper_conn(branch = nn.Sequential(nn.RMSNorm(dim), nn.GRU(dim, dim, batch_first = True))) if is_time_block and rnn_time else None)
rnn_layers.append(hyper_conn(branch = GRULayer(dim, dim)) if is_time_block and rnn_time else None)
self.layers = ModuleList(layers)
self.rnn_layers = ModuleList(rnn_layers)
@ -1557,8 +1581,8 @@ class AxialSpaceTimeTransformer(Module):
def forward(
self,
tokens, # (b t s d)
kv_cache: Tensor | None = None, # (y 2 b h t d)
tokens, # (b t s d)
cache: TransformerIntermediates | None = None,
return_intermediates = False
): # (b t s d) | (y 2 b h t d)
@ -1567,6 +1591,14 @@ class AxialSpaceTimeTransformer(Module):
assert tokens.ndim == 4
# destruct intermediates to cache for attention and rnn respectively
kv_cache = rnn_prev_hiddens = None
if exists(cache):
kv_cache = cache.next_kv_cache
rnn_prev_hiddens = cache.next_rnn_hiddens
# attend functions for space and time
has_kv_cache = exists(kv_cache)
@ -1581,6 +1613,7 @@ class AxialSpaceTimeTransformer(Module):
# prepare cache
time_attn_kv_caches = []
rnn_hiddens = []
if has_kv_cache:
past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
@ -1595,6 +1628,10 @@ class AxialSpaceTimeTransformer(Module):
iter_kv_cache = iter(kv_cache)
rnn_prev_hiddens = default(rnn_prev_hiddens, (None,))
iter_rnn_prev_hiddens = iter(rnn_prev_hiddens)
# rotary
rotary_pos_emb = self.time_rotary(rotary_seq_len, offset = rotary_pos_offset)
@ -1625,10 +1662,12 @@ class AxialSpaceTimeTransformer(Module):
tokens, inverse_pack_batch = pack_one(tokens, '* t d')
tokens, rnn_hiddens = maybe_rnn(tokens) # todo, handle rnn cache
tokens, layer_rnn_hiddens = maybe_rnn(tokens, next(iter_rnn_prev_hiddens, None)) # todo, handle rnn cache
tokens = inverse_pack_batch(tokens)
rnn_hiddens.append(layer_rnn_hiddens)
# when it is an axial time attention block, it should be causal
attend_fn = time_attend if layer_is_time else space_attend
@ -1685,7 +1724,8 @@ class AxialSpaceTimeTransformer(Module):
intermediates = TransformerIntermediates(
stack(time_attn_kv_caches),
safe_stack(normed_time_attn_inputs),
safe_stack(normed_space_attn_inputs)
safe_stack(normed_space_attn_inputs),
safe_stack(rnn_hiddens)
)
return out, intermediates
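A toy illustration of the round trip this return value enables (this is not the real AxialSpaceTimeTransformer, whose constructor lies outside this diff): the intermediates from one call are handed back as the cache argument on the next, so the per-layer GRU hidden states persist across incremental steps alongside the attention kv cache.

from collections import namedtuple
import torch
from torch import nn

ToyIntermediates = namedtuple('ToyIntermediates', ('next_kv_cache', 'next_rnn_hiddens'))

class ToyTimeBlock(nn.Module):
    # stand-in for a single time block whose rnn hidden state must survive between calls
    def __init__(self, dim = 32):
        super().__init__()
        self.gru = nn.GRU(dim, dim, batch_first = True)

    def forward(self, tokens, cache = None):
        prev_hiddens = cache.next_rnn_hiddens if cache is not None else None
        out, hiddens = self.gru(tokens, prev_hiddens)
        # a real block would also stack attention kv caches here; an empty tensor stands in for them
        return out, ToyIntermediates(next_kv_cache = torch.empty(0), next_rnn_hiddens = hiddens)

block = ToyTimeBlock()
cache = None
for _ in range(3):                                     # decode one timestep at a time
    out, cache = block(torch.randn(2, 1, 32), cache)   # thread the cache forward, as generate() does with time_cache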
@ -1717,11 +1757,6 @@ class VideoTokenizer(Module):
encoder_add_decor_aux_loss = False,
decor_auxx_loss_weight = 0.1,
decorr_sample_frac = 0.25,
nd_rotary_kwargs: dict = dict(
rope_min_freq = 1.,
rope_max_freq = 10000.,
rope_p_zero_freqs = 0.
),
num_residual_streams = 1,
):
super().__init__()
@ -1938,7 +1973,7 @@ class VideoTokenizer(Module):
# encoder attention
tokens, (_, time_attn_normed_inputs, space_attn_normed_inputs) = self.encoder_transformer(tokens, return_intermediates = True)
tokens, (_, time_attn_normed_inputs, space_attn_normed_inputs, _) = self.encoder_transformer(tokens, return_intermediates = True)
# latent bottleneck
@ -2011,15 +2046,15 @@ class DynamicsWorldModel(Module):
attn_dim_head = 64,
attn_softclamp_value = 50.,
ff_kwargs: dict = dict(),
use_time_rnn = True,
loss_weight_fn: Callable = ramp_weight,
num_future_predictions = 8, # they do multi-token prediction of 8 steps forward
prob_no_shortcut_train = None, # probability of no shortcut training, defaults to 1 / num_step_sizes
add_reward_embed_to_agent_token = False,
add_reward_embed_dropout = 0.1,
num_discrete_actions: int | tuple[int, ...] = 0,
num_continuous_actions = 0,
continuous_norm_stats = None,
multi_token_pred_len = 8,
multi_token_pred_len = 8, # they do multi-token prediction of 8 steps forward
value_head_mlp_depth = 3,
policy_head_mlp_depth = 3,
latent_flow_loss_weight = 1.,
@ -2229,7 +2264,7 @@ class DynamicsWorldModel(Module):
num_special_spatial_tokens = num_agents,
time_block_every = time_block_every,
final_norm = False,
rnn_time = False,
rnn_time = use_time_rnn,
**transformer_kwargs
)
@ -2376,7 +2411,7 @@ class DynamicsWorldModel(Module):
step_size = 4,
max_timesteps = 16,
env_is_vectorized = False,
use_time_kv_cache = True,
use_time_cache = True,
store_agent_embed = True,
store_old_action_unembeds = True,
):
@ -2415,7 +2450,7 @@ class DynamicsWorldModel(Module):
# maybe time kv cache
time_kv_cache = None
time_cache = None
step_index = 0
@ -2424,14 +2459,14 @@ class DynamicsWorldModel(Module):
latents = self.video_tokenizer(video, return_latents = True)
_, (agent_embed, next_time_kv_cache) = self.forward(
_, (agent_embed, next_time_cache) = self.forward(
latents = latents,
signal_levels = self.max_steps - 1,
step_sizes = step_size,
rewards = rewards,
discrete_actions = discrete_actions,
continuous_actions = continuous_actions,
time_kv_cache = time_kv_cache,
time_cache = time_cache,
latent_is_noised = True,
return_pred_only = True,
return_intermediates = True
@ -2439,8 +2474,8 @@ class DynamicsWorldModel(Module):
# time kv cache
if use_time_kv_cache:
time_kv_cache = next_time_kv_cache
if use_time_cache:
time_cache = next_time_cache
# get one agent
@ -2832,13 +2867,13 @@ class DynamicsWorldModel(Module):
image_width = None,
return_decoded_video = None,
context_signal_noise = 0.1, # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc
time_kv_cache: Tensor | None = None,
use_time_kv_cache = True,
time_cache: Tensor | None = None,
use_time_cache = True,
return_rewards_per_frame = False,
return_agent_actions = False,
return_log_probs_and_values = False,
return_for_policy_optimization = False,
return_time_kv_cache = False,
return_time_cache = False,
store_agent_embed = True,
store_old_action_unembeds = True
@ -2927,7 +2962,7 @@ class DynamicsWorldModel(Module):
# (2) decoding anything off agent embedding (rewards, actions, etc)
take_extra_step = (
use_time_kv_cache or
use_time_cache or
return_rewards_per_frame or
store_agent_embed or
return_agent_actions
@ -2968,7 +3003,7 @@ class DynamicsWorldModel(Module):
signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1)
pred, (agent_embed, next_time_kv_cache) = self.forward(
pred, (agent_embed, next_time_cache) = self.forward(
latents = noised_latent_with_context,
signal_levels = signal_levels_with_context,
step_sizes = step_size,
@ -2978,15 +3013,15 @@ class DynamicsWorldModel(Module):
discrete_actions = decoded_discrete_actions,
continuous_actions = decoded_continuous_actions,
proprio = noised_proprio_with_context,
time_kv_cache = time_kv_cache,
time_cache = time_cache,
latent_is_noised = True,
latent_has_view_dim = True,
return_pred_only = True,
return_intermediates = True,
)
if use_time_kv_cache and is_last_step:
time_kv_cache = next_time_kv_cache
if use_time_cache and is_last_step:
time_cache = next_time_cache
# early break if taking an extra step for agent embedding off cleaned latents for decoding
@ -3135,10 +3170,10 @@ class DynamicsWorldModel(Module):
if not has_at_least_one(return_rewards_per_frame, return_agent_actions, has_proprio):
out = video if return_decoded_video else latents
if not return_time_kv_cache:
if not return_time_cache:
return out
return out, time_kv_cache
return out, time_cache
# returning agent actions, rewards, and log probs + values for policy optimization
@ -3168,10 +3203,10 @@ class DynamicsWorldModel(Module):
gen.values = decoded_values
if not return_time_kv_cache:
if not return_time_cache:
return gen
return gen, time_kv_cache
return gen, time_cache
def forward(
self,
@ -3190,7 +3225,7 @@ class DynamicsWorldModel(Module):
discrete_action_types = None, # (na)
continuous_action_types = None, # (na)
proprio = None, # (b t dp)
time_kv_cache = None,
time_cache = None,
return_pred_only = False,
latent_is_noised = False,
return_all_losses = False,
@ -3410,7 +3445,7 @@ class DynamicsWorldModel(Module):
# main function, needs to be defined as such for shortcut training - additional calls for consistency loss
def get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_kv_cache = False):
def get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_cache = False):
# latents to spatial tokens
@ -3463,7 +3498,7 @@ class DynamicsWorldModel(Module):
# attention
tokens, (next_time_kv_cache, *_) = self.transformer(tokens, kv_cache = time_kv_cache, return_intermediates = True)
tokens, intermediates = self.transformer(tokens, cache = time_cache, return_intermediates = True)
# unpack
@ -3487,10 +3522,10 @@ class DynamicsWorldModel(Module):
if not return_agent_tokens:
return pred
if not return_time_kv_cache:
if not return_time_cache:
return pred, agent_tokens
return pred, (agent_tokens, next_time_kv_cache)
return pred, (agent_tokens, intermediates)
# curry into get_prediction what does not change during first call as well as the shortcut ones
@ -3498,13 +3533,13 @@ class DynamicsWorldModel(Module):
# forward the network
pred, (encoded_agent_tokens, next_time_kv_cache) = _get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, return_agent_tokens = True, return_time_kv_cache = True)
pred, (encoded_agent_tokens, intermediates) = _get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, return_agent_tokens = True, return_time_cache = True)
if return_pred_only:
if not return_intermediates:
return pred
return pred, (encoded_agent_tokens, next_time_kv_cache)
return pred, (encoded_agent_tokens, intermediates)
# pack the predictions to calculate flow for different modalities all at once

View File

@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.1.18"
version = "0.1.19"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

View File

@ -15,7 +15,7 @@ def exists(v):
@param('condition_on_actions', (False, True))
@param('num_residual_streams', (1, 4))
@param('add_reward_embed_to_agent_token', (False, True))
@param('use_time_kv_cache', (False, True))
@param('use_time_cache', (False, True))
@param('var_len', (False, True))
def test_e2e(
pred_orig_latent,
@ -28,7 +28,7 @@ def test_e2e(
condition_on_actions,
num_residual_streams,
add_reward_embed_to_agent_token,
use_time_kv_cache,
use_time_cache,
var_len
):
from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
@ -123,7 +123,7 @@ def test_e2e(
image_width = 128,
batch_size = 2,
return_rewards_per_frame = True,
use_time_kv_cache = use_time_kv_cache
use_time_cache = use_time_cache
)
assert generations.video.shape == (2, 3, 10, 128, 128)
@ -617,9 +617,9 @@ def test_cache_generate():
num_residual_streams = 1
)
generated, time_kv_cache = dynamics.generate(1, return_time_kv_cache = True)
generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
generated, time_cache = dynamics.generate(1, return_time_cache = True)
generated, time_cache = dynamics.generate(1, time_cache = time_cache, return_time_cache = True)
generated, time_cache = dynamics.generate(1, time_cache = time_cache, return_time_cache = True)
@param('vectorized', (False, True))
@param('use_pmpo', (False, True))