From 27ed6d0ba5220f53b4e19c35c617b044168f7ea8 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Sun, 19 Oct 2025 09:16:06 -0700 Subject: [PATCH] fix time kv cache --- dreamer4/dreamer4.py | 24 ++++++++++++++++++++---- pyproject.toml | 2 +- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index f0dc400..dca0796 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -80,6 +80,7 @@ class Experience: step_size: int | None = None agent_index: int = 0 is_from_world_model: bool = True + is_batched: bool = True # helpers @@ -1312,9 +1313,19 @@ class AxialSpaceTimeTransformer(Module): layer_rotary_pos_emb = rotary_pos_emb if layer_is_time else None + # maybe past kv cache + + maybe_kv_cache = next(iter_kv_cache, None) if layer_is_time else None + # attention layer - tokens, kv_cache = attn(tokens, rotary_pos_emb = layer_rotary_pos_emb, attend_fn = attend_fn, kv_cache = next(iter_kv_cache, None), return_kv_cache = True) + tokens, next_kv_cache = attn( + tokens, + rotary_pos_emb = layer_rotary_pos_emb, + attend_fn = attend_fn, + kv_cache = maybe_kv_cache, + return_kv_cache = True + ) tokens = post_attn_rearrange(tokens) @@ -1325,7 +1336,7 @@ class AxialSpaceTimeTransformer(Module): # save kv cache if is time layer if layer_is_time: - time_attn_kv_caches.append(kv_cache) + time_attn_kv_caches.append(next_kv_cache) tokens = self.reduce_streams(tokens) @@ -2624,7 +2635,12 @@ class DynamicsWorldModel(Module): class Dreamer(Module): def __init__( self, - video_tokenizer: VideoTokenizer, - dynamics_model: DynamicsWorldModel, + state_tokenizer: VideoTokenizer, + world_model: DynamicsWorldModel, ): super().__init__() + self.state_toke = state_tokenizer + self.world_model = world_model + + def interact_with_sim(self, env) -> Experience: + raise NotImplementedError diff --git a/pyproject.toml b/pyproject.toml index 77dd976..094eab5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.0.44" +version = "0.0.45" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }