diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 41809cd..d2265e3 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -1972,7 +1972,8 @@ class DynamicsWorldModel(Module): seed = None, agent_index = 0, step_size = 4, - max_timesteps = 16 + max_timesteps = 16, + use_time_kv_cache = True ): assert exists(self.video_tokenizer) @@ -1990,22 +1991,34 @@ class DynamicsWorldModel(Module): values = None latents = None + # maybe time kv cache + + time_kv_cache = None + for _ in range(max_timesteps): latents = self.video_tokenizer(video, return_latents = True) - _, (agent_embed, _) = self.forward( + _, (agent_embed, next_time_kv_cache) = self.forward( latents = latents, signal_levels = self.max_steps - 1, step_sizes = step_size, rewards = rewards, discrete_actions = discrete_actions, continuous_actions = continuous_actions, + time_kv_cache = time_kv_cache, latent_is_noised = True, return_pred_only = True, return_intermediates = True ) + # time kv cache + + if use_time_kv_cache: + time_kv_cache = next_time_kv_cache + + # get one agent + one_agent_embed = agent_embed[..., -1:, agent_index, :] policy_embed = self.policy_head(one_agent_embed) diff --git a/pyproject.toml b/pyproject.toml index 8483df6..a72e080 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.0.61" +version = "0.0.62" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }