diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index b661a24..652aa02 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -861,7 +861,19 @@ def apply_rotations(
     rotations, # (h n d) | (n d)
     t          # (b h n d)
 ):
-    heads, dtype = t.shape[1], t.dtype
+
+    heads, seq_len, dtype = *t.shape[1:3], t.dtype
+
+    rotations_seq_len = rotations.shape[-2]
+
+    # handle kv caching with rotations
+    # todo - only fetch rotary embedding for one timestep
+
+    if rotations_seq_len > seq_len:
+        rotations = rotations[-seq_len:]
+
+    # precision
+    t = t.float()
 
     # handle gqa for rotary
 
@@ -1316,7 +1328,13 @@ class AxialSpaceTimeTransformer(Module):
 
         time_attn_kv_caches = []
 
+        has_kv_cache = exists(kv_cache)
+
+        if has_kv_cache:
+            past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
+
         kv_cache = default(kv_cache, (None,))
+
         iter_kv_cache = iter(kv_cache)
 
         # attention
@@ -1362,6 +1380,10 @@ class AxialSpaceTimeTransformer(Module):
 
         out = self.final_norm(tokens)
 
+        if has_kv_cache:
+            # just concat the past tokens back on for now, todo - clean up the logic
+            out = cat((past_tokens, out), dim = 1)
+
         if not return_kv_cache:
             return out
 
@@ -2020,9 +2042,11 @@ class DynamicsWorldModel(Module):
         image_width = None,
         return_decoded_video = None,
         context_signal_noise = 0.1, # they do a noising of the past, this was from an old diffusion world modeling paper from EPFL iirc
+        use_time_kv_cache = True,
         return_rewards_per_frame = False,
         return_agent_actions = False,
-        return_log_probs_and_values = False
+        return_log_probs_and_values = False,
+        return_time_kv_cache = False
     ):
 
         # (b t n d) | (b c t h w)
@@ -2072,6 +2096,10 @@ class DynamicsWorldModel(Module):
         if return_rewards_per_frame:
             decoded_rewards = empty((batch_size, 0), device = self.device, dtype = torch.float32)
 
+        # handle maybe time kv cache
+
+        time_kv_cache = None
+
         # while all the frames of the video (per latent) is not generated
 
         while latents.shape[1] < time_steps:
@@ -2080,6 +2108,8 @@ class DynamicsWorldModel(Module):
             noised_latent = randn((batch_size, 1, *latent_shape), device = self.device)
 
             for step in range(num_steps):
+                is_last_step = (step + 1) == num_steps
+
                 signal_levels = full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
 
                 noised_context = latents.lerp(past_context_noise, context_signal_noise) # the paragraph after eq (8)
@@ -2088,7 +2118,7 @@ class DynamicsWorldModel(Module):
 
                 signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1)
 
-                pred, agent_embed = self.forward(
+                pred, (agent_embed, next_time_kv_cache) = self.forward(
                     latents = noised_latent_with_context,
                     signal_levels = signal_levels_with_context,
                     step_sizes = step_size,
@@ -2096,11 +2126,17 @@ class DynamicsWorldModel(Module):
                     tasks = tasks,
                     discrete_actions = decoded_discrete_actions,
                     continuous_actions = decoded_continuous_actions,
+                    time_kv_cache = time_kv_cache,
                     latent_is_noised = True,
                     return_pred_only = True,
-                    return_agent_tokens = True
+                    return_intermediates = True,
                 )
 
+                if use_time_kv_cache and is_last_step:
+                    time_kv_cache = next_time_kv_cache
+
+                # unpack pred
+
                 _, pred = unpack(pred, pack_context_shape, 'b * n d')
 
                 # derive flow, based on whether in x-space or not
@@ -2186,7 +2222,12 @@ class DynamicsWorldModel(Module):
         # only return video or latent if not requesting anything else, for first stage training
 
         if not has_at_least_one(return_rewards_per_frame, return_agent_actions):
-            return video if return_decoded_video else latents
+            out = video if return_decoded_video else latents
+
+            if not return_time_kv_cache:
+                return out
+
+            return out, time_kv_cache
 
         # returning agent actions, rewards, and log probs + values for policy optimization
 
@@ -2209,7 +2250,10 @@ class DynamicsWorldModel(Module):
 
             gen.values = decoded_values
 
-        return gen
+        if not return_time_kv_cache:
+            return gen
+
+        return gen, time_kv_cache
 
     def forward(
         self,
@@ -2226,10 +2270,11 @@ class DynamicsWorldModel(Module):
         continuous_actions = None,      # (b t na) | (b t-1 na)
         discrete_action_types = None,   # (na)
         continuous_action_types = None, # (na)
+        time_kv_cache = None,
         return_pred_only = False,
         latent_is_noised = False,
         return_all_losses = False,
-        return_agent_tokens = False,
+        return_intermediates = False,
         add_autoregressive_action_loss = False,
         update_loss_ema = None
     ):
@@ -2397,6 +2442,9 @@ class DynamicsWorldModel(Module):
 
             action_tokens = add('1 d, b t d', self.action_learned_embed, action_tokens)
 
+        elif self.action_embedder.has_actions:
+            action_tokens = torch.zeros_like(agent_tokens[:, :, 0:1])
+
         else:
             action_tokens = agent_tokens[:, :, 0:0] # else empty off agent tokens
 
@@ -2440,7 +2488,7 @@ class DynamicsWorldModel(Module):
 
         # attention
 
-        tokens, time_kv_cache = self.transformer(tokens, return_kv_cache = True)
+        tokens, next_time_kv_cache = self.transformer(tokens, kv_cache = time_kv_cache, return_kv_cache = True)
 
         # unpack
 
@@ -2455,7 +2503,10 @@ class DynamicsWorldModel(Module):
             if not return_agent_tokens:
                 return pred
 
-            return pred, agent_tokens
+            if not return_time_kv_cache:
+                return pred, agent_tokens
+
+            return pred, (agent_tokens, next_time_kv_cache)
 
         # curry into get_prediction what does not change during first call as well as the shortcut ones
 
@@ -2463,13 +2514,13 @@ class DynamicsWorldModel(Module):
 
         # forward the network
 
-        pred, encoded_agent_tokens = _get_prediction(noised_latents, signal_levels, step_sizes_log2, return_agent_tokens = True)
+        pred, (encoded_agent_tokens, next_time_kv_cache) = _get_prediction(noised_latents, signal_levels, step_sizes_log2, return_agent_tokens = True, return_time_kv_cache = True)
 
         if return_pred_only:
 
-            if not return_agent_tokens:
+            if not return_intermediates:
                 return pred
 
-            return pred, encoded_agent_tokens
+            return pred, (encoded_agent_tokens, next_time_kv_cache)
 
         # determine the target for the loss
diff --git a/pyproject.toml b/pyproject.toml
index 4269618..8dccab7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.48"
+version = "0.0.49"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index b13f49f..44b74bf 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -12,6 +12,7 @@ import torch
 @param('condition_on_actions', (False, True))
 @param('num_residual_streams', (1, 4))
 @param('add_reward_embed_to_agent_token', (False, True))
+@param('use_time_kv_cache', (False, True))
 def test_e2e(
     pred_orig_latent,
     grouped_query_attn,
@@ -22,7 +23,8 @@ def test_e2e(
     signal_and_step_passed_in,
     condition_on_actions,
     num_residual_streams,
-    add_reward_embed_to_agent_token
+    add_reward_embed_to_agent_token,
+    use_time_kv_cache
 ):
 
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
@@ -108,7 +110,8 @@ def test_e2e(
         image_height = 128,
         image_width = 128,
         batch_size = 2,
-        return_rewards_per_frame = True
+        return_rewards_per_frame = True,
+        use_time_kv_cache = use_time_kv_cache
     )
 
     assert generations.video.shape == (2, 3, 10, 128, 128)
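
The hunks above are easier to follow with a few standalone sketches. These are illustrative only, not part of the patch: helper names, tensor shapes, and stubbed internals are assumptions. First, the cache-aware rotary slicing introduced in apply_rotations: during cached decoding only the newest timestep's tokens are passed in, so the rotations precomputed for the full sequence are sliced to their trailing positions (slice_rotations_for_cache is a hypothetical name).

import torch

def slice_rotations_for_cache(rotations, t):
    # rotations: (n, d) rotary embeddings precomputed for the full sequence
    # t:         (b, h, n_new, d) only the tokens for the current decoding step(s)
    seq_len = t.shape[2]
    if rotations.shape[-2] > seq_len:
        # keep only the trailing positions that line up with the incoming tokens
        rotations = rotations[-seq_len:]
    return rotations

rotations = torch.randn(16, 64)  # rotations for 16 positions
t = torch.randn(1, 8, 1, 64)     # a single new timestep during cached decoding

assert slice_rotations_for_cache(rotations, t).shape == (1, 64)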
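
Likewise, a toy version of the kv-cache token handling in AxialSpaceTimeTransformer: when a cache is present, only the last timestep is run through attention, and the untouched past tokens are concatenated back on afterwards (the "todo - clean up the logic" path). Layer internals are elided here.

import torch
from torch import cat

tokens = torch.randn(2, 5, 3, 64)  # (batch, time, space, dim)
kv_cache = object()                # stand-in for a real cache being passed in

has_kv_cache = kv_cache is not None

if has_kv_cache:
    # attend only over the newest timestep; the past is served by the cache
    past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]

# ... attention + feedforward layers would transform `tokens` here ...

out = tokens  # stand-in for self.final_norm(tokens)

if has_kv_cache:
    # concat the past tokens back on, mirroring the patch
    out = cat((past_tokens, out), dim = 1)

assert out.shape == (2, 5, 3, 64)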
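
Finally, the sampling-loop pattern the generation changes implement: across the num_steps shortcut denoising steps of each new frame, the cache returned by the model is only adopted on the last step, once the frame is fully denoised, so subsequent frames attend over clean keys/values. A schematic loop with the model call stubbed out:

num_frames, num_steps = 4, 8
time_kv_cache = None
use_time_kv_cache = True

def model_forward(time_kv_cache):
    # stand-in for self.forward(..., time_kv_cache = time_kv_cache,
    # return_pred_only = True, return_intermediates = True)
    pred, next_time_kv_cache = object(), object()
    return pred, next_time_kv_cache

for frame in range(num_frames):
    for step in range(num_steps):
        is_last_step = (step + 1) == num_steps

        pred, next_time_kv_cache = model_forward(time_kv_cache)

        # only keep the cache built on the final, fully denoised step
        if use_time_kv_cache and is_last_step:
            time_kv_cache = next_time_kv_cache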