From 11fd2c477c8c457fbc026e8ac9d7ebeddf88a8d9 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Tue, 21 Oct 2025 07:09:26 -0700 Subject: [PATCH] fix rotary embeddings in prsence of kv caching --- dreamer4/dreamer4.py | 21 ++++++++++++++------- pyproject.toml | 2 +- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index d839f59..dd11a69 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -847,11 +847,12 @@ class Rotary1D(Module): def forward( self, - seq_len + seq_len, + offset = 0 ): device, dtype = self.inv_freq.device, self.inv_freq.dtype - t = torch.arange(seq_len, device = device).type(dtype) + t = torch.arange(seq_len, device = device).type(dtype) + offset freqs = einsum(t, self.inv_freq, 'i, j -> i j') return cat((freqs, freqs), dim = -1) @@ -867,7 +868,6 @@ def apply_rotations( rotations_seq_len = rotations.shape[-2] # handle kv caching with rotations - # todo - only fetch rotary embedding for one timestep if rotations_seq_len > seq_len: rotations = rotations[-seq_len:] @@ -1320,23 +1320,30 @@ class AxialSpaceTimeTransformer(Module): time_attend = get_attend_fn(causal = True, seq_len = time, k_seq_len = time, **attend_kwargs) - # rotary - - rotary_pos_emb = self.time_rotary(time) - # prepare cache time_attn_kv_caches = [] has_kv_cache = exists(kv_cache) + if has_kv_cache: past_tokens, tokens = tokens[:, :-1], tokens[:, -1:] + rotary_seq_len = 1 + rotary_pos_offset = past_tokens.shape[-2] + else: + rotary_seq_len = time + rotary_pos_offset = 0 + kv_cache = default(kv_cache, (None,)) iter_kv_cache = iter(kv_cache) + # rotary + + rotary_pos_emb = self.time_rotary(rotary_seq_len, offset = rotary_pos_offset) + # attention tokens = self.expand_streams(tokens) diff --git a/pyproject.toml b/pyproject.toml index 71e9c95..1a76c38 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.0.50" +version = "0.0.52" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }