From 690ecf07dc7fe04df7b2b51961c1a90d4f604f5c Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Tue, 11 Nov 2025 17:04:02 -0800
Subject: [PATCH] fix the rnn time caching issue

---
 dreamer4/__init__.py  |   3 +-
 dreamer4/dreamer4.py  | 115 +++++++++++++++++++++++++++---------
 pyproject.toml        |   2 +-
 tests/test_dreamer.py |  12 ++---
 4 files changed, 84 insertions(+), 48 deletions(-)

diff --git a/dreamer4/__init__.py b/dreamer4/__init__.py
index d7ce984..6077132 100644
--- a/dreamer4/__init__.py
+++ b/dreamer4/__init__.py
@@ -1,6 +1,7 @@
 from dreamer4.dreamer4 import (
     VideoTokenizer,
-    DynamicsWorldModel
+    DynamicsWorldModel,
+    AxialSpaceTimeTransformer
 )
 
diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 4c253cb..9865192 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -1,4 +1,5 @@
 from __future__ import annotations
+from typing import Callable
 
 import math
 from math import ceil, log2
@@ -76,7 +77,7 @@ WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_
 
 AttentionIntermediates = namedtuple('AttentionIntermediates', ('next_kv_cache', 'normed_inputs'))
 
-TransformerIntermediates = namedtuple('TransformerIntermediates', ('next_kv_cache', 'normed_time_inputs', 'normed_space_inputs'))
+TransformerIntermediates = namedtuple('TransformerIntermediates', ('next_kv_cache', 'normed_time_inputs', 'normed_space_inputs', 'next_rnn_hiddens'))
 
 MaybeTensor = Tensor | None
@@ -1453,6 +1454,29 @@ class SwiGLUFeedforward(Module):
 
         return self.proj_out(x)
 
+# rnn
+
+class GRULayer(Module):
+    def __init__(
+        self,
+        dim,
+        dim_out
+    ):
+        super().__init__()
+        self.norm = nn.RMSNorm(dim)
+        self.gru = nn.GRU(dim, dim_out, batch_first = True)
+
+    def forward(
+        self,
+        x,
+        prev_hiddens = None
+    ):
+        x = self.norm(x)
+
+        x, hiddens = self.gru(x, prev_hiddens)
+
+        return x, hiddens
+
 # axial space time transformer
 
 class AxialSpaceTimeTransformer(Module):
@@ -1531,7 +1555,7 @@ class AxialSpaceTimeTransformer(Module):
                 hyper_conn(branch = SwiGLUFeedforward(dim = dim, **ff_kwargs))
             ]))
 
-            rnn_layers.append(hyper_conn(branch = nn.Sequential(nn.RMSNorm(dim), nn.GRU(dim, dim, batch_first = True))) if is_time_block and rnn_time else None)
+            rnn_layers.append(hyper_conn(branch = GRULayer(dim, dim)) if is_time_block and rnn_time else None)
 
         self.layers = ModuleList(layers)
         self.rnn_layers = ModuleList(rnn_layers)
@@ -1557,8 +1581,8 @@ class AxialSpaceTimeTransformer(Module):
 
     def forward(
         self,
-        tokens,                         # (b t s d)
-        kv_cache: Tensor | None = None, # (y 2 b h t d)
+        tokens, # (b t s d)
+        cache: TransformerIntermediates | None = None,
         return_intermediates = False
     ): # (b t s d) | (y 2 b h t d)
 
@@ -1567,6 +1591,14 @@ class AxialSpaceTimeTransformer(Module):
 
         assert tokens.ndim == 4
 
+        # destructure intermediates into the attention and rnn caches respectively
+
+        kv_cache = rnn_prev_hiddens = None
+
+        if exists(cache):
+            kv_cache = cache.next_kv_cache
+            rnn_prev_hiddens = cache.next_rnn_hiddens
+
         # attend functions for space and time
 
         has_kv_cache = exists(kv_cache)
@@ -1581,6 +1613,7 @@ class AxialSpaceTimeTransformer(Module):
         # prepare cache
 
         time_attn_kv_caches = []
+        rnn_hiddens = []
 
         if has_kv_cache:
             past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
@@ -1595,6 +1628,10 @@ class AxialSpaceTimeTransformer(Module):
 
         iter_kv_cache = iter(kv_cache)
 
+        rnn_prev_hiddens = default(rnn_prev_hiddens, (None,))
+
+        iter_rnn_prev_hiddens = iter(rnn_prev_hiddens)
+
         # rotary
 
         rotary_pos_emb = self.time_rotary(rotary_seq_len, offset = rotary_pos_offset)
@@ -1625,10 +1662,12 @@ class AxialSpaceTimeTransformer(Module):
 
             tokens, inverse_pack_batch = pack_one(tokens, '* t d')
 
-            tokens, rnn_hiddens = maybe_rnn(tokens) # todo, handle rnn cache
+            tokens, layer_rnn_hiddens = maybe_rnn(tokens, next(iter_rnn_prev_hiddens, None))
 
             tokens = inverse_pack_batch(tokens)
 
+            rnn_hiddens.append(layer_rnn_hiddens)
+
             # when is a axial time attention block, should be causal
 
             attend_fn = time_attend if layer_is_time else space_attend
@@ -1685,7 +1724,8 @@ class AxialSpaceTimeTransformer(Module):
         intermediates = TransformerIntermediates(
             stack(time_attn_kv_caches),
             safe_stack(normed_time_attn_inputs),
-            safe_stack(normed_space_attn_inputs)
+            safe_stack(normed_space_attn_inputs),
+            safe_stack(rnn_hiddens)
         )
 
         return out, intermediates
@@ -1717,11 +1757,6 @@ class VideoTokenizer(Module):
         encoder_add_decor_aux_loss = False,
         decor_auxx_loss_weight = 0.1,
         decorr_sample_frac = 0.25,
-        nd_rotary_kwargs: dict = dict(
-            rope_min_freq = 1.,
-            rope_max_freq = 10000.,
-            rope_p_zero_freqs = 0.
-        ),
         num_residual_streams = 1,
     ):
         super().__init__()
@@ -1938,7 +1973,7 @@ class VideoTokenizer(Module):
 
         # encoder attention
 
-        tokens, (_, time_attn_normed_inputs, space_attn_normed_inputs) = self.encoder_transformer(tokens, return_intermediates = True)
+        tokens, (_, time_attn_normed_inputs, space_attn_normed_inputs, _) = self.encoder_transformer(tokens, return_intermediates = True)
 
         # latent bottleneck
 
@@ -2011,15 +2046,15 @@ class DynamicsWorldModel(Module):
         attn_dim_head = 64,
         attn_softclamp_value = 50.,
         ff_kwargs: dict = dict(),
+        use_time_rnn = True,
         loss_weight_fn: Callable = ramp_weight,
-        num_future_predictions = 8, # they do multi-token prediction of 8 steps forward
         prob_no_shortcut_train = None, # probability of no shortcut training, defaults to 1 / num_step_sizes
         add_reward_embed_to_agent_token = False,
         add_reward_embed_dropout = 0.1,
         num_discrete_actions: int | tuple[int, ...] = 0,
         num_continuous_actions = 0,
         continuous_norm_stats = None,
-        multi_token_pred_len = 8,
+        multi_token_pred_len = 8, # they do multi-token prediction of 8 steps forward
         value_head_mlp_depth = 3,
         policy_head_mlp_depth = 3,
         latent_flow_loss_weight = 1.,
@@ -2229,7 +2264,7 @@ class DynamicsWorldModel(Module):
             num_special_spatial_tokens = num_agents,
             time_block_every = time_block_every,
             final_norm = False,
-            rnn_time = False,
+            rnn_time = use_time_rnn,
             **transformer_kwargs
         )
 
@@ -2376,7 +2411,7 @@ class DynamicsWorldModel(Module):
         step_size = 4,
         max_timesteps = 16,
         env_is_vectorized = False,
-        use_time_kv_cache = True,
+        use_time_cache = True,
         store_agent_embed = True,
         store_old_action_unembeds = True,
     ):
@@ -2415,7 +2450,7 @@ class DynamicsWorldModel(Module):
 
         # maybe time kv cache
 
-        time_kv_cache = None
+        time_cache = None
 
         step_index = 0
 
@@ -2424,14 +2459,14 @@ class DynamicsWorldModel(Module):
 
             latents = self.video_tokenizer(video, return_latents = True)
 
-            _, (agent_embed, next_time_kv_cache) = self.forward(
+            _, (agent_embed, next_time_cache) = self.forward(
                 latents = latents,
                 signal_levels = self.max_steps - 1,
                 step_sizes = step_size,
                 rewards = rewards,
                 discrete_actions = discrete_actions,
                 continuous_actions = continuous_actions,
-                time_kv_cache = time_kv_cache,
+                time_cache = time_cache,
                 latent_is_noised = True,
                 return_pred_only = True,
                 return_intermediates = True
@@ -2439,8 +2474,8 @@ class DynamicsWorldModel(Module):
             )
 
             # time kv cache
 
-            if use_time_kv_cache:
-                time_kv_cache = next_time_kv_cache
+            if use_time_cache:
+                time_cache = next_time_cache
 
             # get one agent
 
@@ -2832,13 +2867,13 @@ class DynamicsWorldModel(Module):
         image_width = None,
         return_decoded_video = None,
         context_signal_noise = 0.1, # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc
-        time_kv_cache: Tensor | None = None,
-        use_time_kv_cache = True,
+        time_cache: TransformerIntermediates | None = None,
+        use_time_cache = True,
         return_rewards_per_frame = False,
         return_agent_actions = False,
         return_log_probs_and_values = False,
         return_for_policy_optimization = False,
-        return_time_kv_cache = False,
+        return_time_cache = False,
 
         store_agent_embed = True,
         store_old_action_unembeds = True
@@ -2927,7 +2962,7 @@ class DynamicsWorldModel(Module):
         # (2) decoding anything off agent embedding (rewards, actions, etc)
 
         take_extra_step = (
-            use_time_kv_cache or
+            use_time_cache or
             return_rewards_per_frame or
             store_agent_embed or
             return_agent_actions
@@ -2968,7 +3003,7 @@ class DynamicsWorldModel(Module):
 
             signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1)
 
-            pred, (agent_embed, next_time_kv_cache) = self.forward(
+            pred, (agent_embed, next_time_cache) = self.forward(
                 latents = noised_latent_with_context,
                 signal_levels = signal_levels_with_context,
                 step_sizes = step_size,
@@ -2978,15 +3013,15 @@ class DynamicsWorldModel(Module):
                 discrete_actions = decoded_discrete_actions,
                 continuous_actions = decoded_continuous_actions,
                 proprio = noised_proprio_with_context,
-                time_kv_cache = time_kv_cache,
+                time_cache = time_cache,
                 latent_is_noised = True,
                 latent_has_view_dim = True,
                 return_pred_only = True,
                 return_intermediates = True,
             )
 
-            if use_time_kv_cache and is_last_step:
-                time_kv_cache = next_time_kv_cache
+            if use_time_cache and is_last_step:
+                time_cache = next_time_cache
 
             # early break if taking an extra step for agent embedding off cleaned latents for decoding
@@ -3135,10 +3170,10 @@ class DynamicsWorldModel(Module):
 
         if not has_at_least_one(return_rewards_per_frame, return_agent_actions, has_proprio):
             out = video if return_decoded_video else latents
 
-            if not return_time_kv_cache:
+            if not return_time_cache:
                 return out
 
-            return out, time_kv_cache
+            return out, time_cache
 
         # returning agent actions, rewards, and log probs + values for policy optimization
 
@@ -3168,10 +3203,10 @@ class DynamicsWorldModel(Module):
 
             gen.values = decoded_values
 
-        if not return_time_kv_cache:
+        if not return_time_cache:
             return gen
 
-        return gen, time_kv_cache
+        return gen, time_cache
 
     def forward(
         self,
@@ -3190,7 +3225,7 @@ class DynamicsWorldModel(Module):
         discrete_action_types = None, # (na)
         continuous_action_types = None, # (na)
         proprio = None, # (b t dp)
-        time_kv_cache = None,
+        time_cache = None,
         return_pred_only = False,
         latent_is_noised = False,
         return_all_losses = False,
@@ -3410,7 +3445,7 @@ class DynamicsWorldModel(Module):
 
        # main function, needs to be defined as such for shortcut training - additional calls for consistency loss
 
-        def get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_kv_cache = False):
+        def get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_cache = False):
 
             # latents to spatial tokens
 
@@ -3463,7 +3498,7 @@ class DynamicsWorldModel(Module):
 
             # attention
 
-            tokens, (next_time_kv_cache, *_) = self.transformer(tokens, kv_cache = time_kv_cache, return_intermediates = True)
+            tokens, intermediates = self.transformer(tokens, cache = time_cache, return_intermediates = True)
 
             # unpack
 
@@ -3487,10 +3522,10 @@ class DynamicsWorldModel(Module):
 
             if not return_agent_tokens:
                 return pred
 
-            if not return_time_kv_cache:
+            if not return_time_cache:
                 return pred, agent_tokens
 
-            return pred, (agent_tokens, next_time_kv_cache)
+            return pred, (agent_tokens, intermediates)
 
         # curry into get_prediction what does not change during first call as well as the shortcut ones
 
@@ -3498,13 +3533,13 @@ class DynamicsWorldModel(Module):
 
         # forward the network
 
-        pred, (encoded_agent_tokens, next_time_kv_cache) = _get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, return_agent_tokens = True, return_time_kv_cache = True)
+        pred, (encoded_agent_tokens, intermediates) = _get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, return_agent_tokens = True, return_time_cache = True)
 
         if return_pred_only:
             if not return_intermediates:
                 return pred
 
-            return pred, (encoded_agent_tokens, next_time_kv_cache)
+            return pred, (encoded_agent_tokens, intermediates)
 
         # pack the predictions to calculate flow for different modalities all at once
 
diff --git a/pyproject.toml b/pyproject.toml
index fa10143..72c6d3f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.1.18"
+version = "0.1.19"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index 92731a9..eab9d36 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -15,7 +15,7 @@ def exists(v):
 @param('condition_on_actions', (False, True))
 @param('num_residual_streams', (1, 4))
 @param('add_reward_embed_to_agent_token', (False, True))
-@param('use_time_kv_cache', (False, True))
+@param('use_time_cache', (False, True))
 @param('var_len', (False, True))
 def test_e2e(
     pred_orig_latent,
@@ -28,7 +28,7 @@ def test_e2e(
     condition_on_actions,
     num_residual_streams,
     add_reward_embed_to_agent_token,
-    use_time_kv_cache,
+    use_time_cache,
     var_len
 ):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
@@ -123,7 +123,7 @@ def test_e2e(
         image_width = 128,
         batch_size = 2,
         return_rewards_per_frame = True,
-        use_time_kv_cache = use_time_kv_cache
+        use_time_cache = use_time_cache
     )
 
     assert generations.video.shape == (2, 3, 10, 128, 128)
@@ -617,9 +617,9 @@ def test_cache_generate():
         num_residual_streams = 1
     )
 
-    generated, time_kv_cache = dynamics.generate(1, return_time_kv_cache = True)
-    generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
-    generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
+    generated, time_cache = dynamics.generate(1, return_time_cache = True)
+    generated, time_cache = dynamics.generate(1, time_cache = time_cache, return_time_cache = True)
+    generated, time_cache = dynamics.generate(1, time_cache = time_cache, return_time_cache = True)
 
 @param('vectorized', (False, True))
 @param('use_pmpo', (False, True))
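
Usage sketch: with this fix, the time cache threaded through `DynamicsWorldModel.generate` is the full `TransformerIntermediates` namedtuple, carrying the attention kv cache (`next_kv_cache`) together with the per-layer GRU hidden states (`next_rnn_hiddens`); previously only the kv cache survived across calls and the rnn state was dropped (the old `# todo, handle rnn cache`). The following is a minimal sketch of resuming generation from the cache, mirroring `test_cache_generate` above; the `DynamicsWorldModel` constructor arguments are elided here and would be filled in as in `tests/test_dreamer.py`.

    from dreamer4 import DynamicsWorldModel

    # constructor arguments elided, see tests/test_dreamer.py
    dynamics = DynamicsWorldModel(...)

    # the first call primes both the attention kv cache and the gru hidden states

    generated, time_cache = dynamics.generate(1, return_time_cache = True)

    # later calls resume from the cached intermediates of the previous call,
    # so both attention and the time rnn continue from where they left off

    generated, time_cache = dynamics.generate(1, time_cache = time_cache, return_time_cache = True)

Note that passing the whole intermediates tuple, rather than a bare kv tensor, is what lets the GRU layers added by `use_time_rnn = True` stay consistent across chunked generation.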