fix time kv cache

This commit is contained in:
lucidrains 2025-10-19 09:16:06 -07:00
parent 4930002e99
commit 27ed6d0ba5
2 changed files with 21 additions and 5 deletions

View File

@ -80,6 +80,7 @@ class Experience:
step_size: int | None = None
agent_index: int = 0
is_from_world_model: bool = True
is_batched: bool = True
# helpers
@ -1312,9 +1313,19 @@ class AxialSpaceTimeTransformer(Module):
layer_rotary_pos_emb = rotary_pos_emb if layer_is_time else None
# maybe past kv cache
maybe_kv_cache = next(iter_kv_cache, None) if layer_is_time else None
# attention layer
tokens, kv_cache = attn(tokens, rotary_pos_emb = layer_rotary_pos_emb, attend_fn = attend_fn, kv_cache = next(iter_kv_cache, None), return_kv_cache = True)
tokens, next_kv_cache = attn(
tokens,
rotary_pos_emb = layer_rotary_pos_emb,
attend_fn = attend_fn,
kv_cache = maybe_kv_cache,
return_kv_cache = True
)
tokens = post_attn_rearrange(tokens)
@ -1325,7 +1336,7 @@ class AxialSpaceTimeTransformer(Module):
# save kv cache if is time layer
if layer_is_time:
time_attn_kv_caches.append(kv_cache)
time_attn_kv_caches.append(next_kv_cache)
tokens = self.reduce_streams(tokens)
@ -2624,7 +2635,12 @@ class DynamicsWorldModel(Module):
class Dreamer(Module):
def __init__(
self,
video_tokenizer: VideoTokenizer,
dynamics_model: DynamicsWorldModel,
state_tokenizer: VideoTokenizer,
world_model: DynamicsWorldModel,
):
super().__init__()
self.state_toke = state_tokenizer
self.world_model = world_model
def interact_with_sim(self, env) -> Experience:
raise NotImplementedError

View File

@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.44"
version = "0.0.45"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }