Wire up the time KV cache when interacting with the simulator / environment

This commit is contained in:
lucidrains 2025-10-22 08:39:11 -07:00
parent d82debb7a6
commit 7ecc5d03e8
2 changed files with 16 additions and 3 deletions

View File

@@ -1972,7 +1972,8 @@ class DynamicsWorldModel(Module):
seed = None,
agent_index = 0,
step_size = 4,
max_timesteps = 16
max_timesteps = 16,
use_time_kv_cache = True
):
assert exists(self.video_tokenizer)
@@ -1990,22 +1991,34 @@ class DynamicsWorldModel(Module):
values = None
latents = None
# maybe time kv cache
time_kv_cache = None
for _ in range(max_timesteps):
latents = self.video_tokenizer(video, return_latents = True)
_, (agent_embed, _) = self.forward(
_, (agent_embed, next_time_kv_cache) = self.forward(
latents = latents,
signal_levels = self.max_steps - 1,
step_sizes = step_size,
rewards = rewards,
discrete_actions = discrete_actions,
continuous_actions = continuous_actions,
time_kv_cache = time_kv_cache,
latent_is_noised = True,
return_pred_only = True,
return_intermediates = True
)
# time kv cache
if use_time_kv_cache:
time_kv_cache = next_time_kv_cache
# get one agent
one_agent_embed = agent_embed[..., -1:, agent_index, :]
policy_embed = self.policy_head(one_agent_embed)

View File

@@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.61"
version = "0.0.62"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }