fix rotary embeddings in presence of kv caching

lucidrains 2025-10-21 07:09:26 -07:00
parent 7195bbb196
commit 11fd2c477c
2 changed files with 15 additions and 8 deletions


@@ -847,11 +847,12 @@ class Rotary1D(Module):
     def forward(
         self,
-        seq_len
+        seq_len,
+        offset = 0
     ):
         device, dtype = self.inv_freq.device, self.inv_freq.dtype
-        t = torch.arange(seq_len, device = device).type(dtype)
+        t = torch.arange(seq_len, device = device).type(dtype) + offset
         freqs = einsum(t, self.inv_freq, 'i, j -> i j')
         return cat((freqs, freqs), dim = -1)
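For context, here is a minimal runnable sketch of what the new `offset` argument does, mirroring the patched forward. The `Rotary1DSketch` class, its constructor, and the `theta` default are illustrative assumptions; only the forward body comes from the diff (assuming `einsum` is einops' and `inv_freq` the usual rotary inverse-frequency buffer):

import torch
from torch import nn, cat
from einops import einsum

class Rotary1DSketch(nn.Module):
    # hypothetical stand-in for Rotary1D, for illustration only
    def __init__(self, dim, theta = 10000.):
        super().__init__()
        # standard rotary inverse frequencies
        inv_freq = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, seq_len, offset = 0):
        device, dtype = self.inv_freq.device, self.inv_freq.dtype
        # the fix: positions start at `offset`, so a decoded timestep is
        # rotated at its true absolute position instead of position 0
        t = torch.arange(seq_len, device = device).type(dtype) + offset
        freqs = einsum(t, self.inv_freq, 'i, j -> i j')
        return cat((freqs, freqs), dim = -1)

# the single-step embedding at offset n matches row n of the full table
rotary = Rotary1DSketch(dim = 64)
assert torch.allclose(rotary(8)[-1:], rotary(1, offset = 7))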
@@ -867,7 +868,6 @@ def apply_rotations(
     rotations_seq_len = rotations.shape[-2]
     # handle kv caching with rotations
-    # todo - only fetch rotary embedding for one timestep
     if rotations_seq_len > seq_len:
         rotations = rotations[-seq_len:]
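This slicing only fires when the rotation table still covers more positions than the current queries. A sketch of how it plugs into the standard rotary application, with `rotate_half` being the usual helper (exact signatures here are assumptions, not the repo's):

import torch

def rotate_half(x):
    # standard rotary helper: split the feature dim in half and rotate
    x1, x2 = x.chunk(2, dim = -1)
    return torch.cat((-x2, x1), dim = -1)

def apply_rotations_sketch(t, rotations):
    seq_len, rotations_seq_len = t.shape[-2], rotations.shape[-2]
    # kv caching: rotations may span the full sequence while only the
    # newest queries are present - keep just the trailing rows
    if rotations_seq_len > seq_len:
        rotations = rotations[-seq_len:]
    return t * rotations.cos() + rotate_half(t) * rotations.sin()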
@@ -1320,23 +1320,30 @@ class AxialSpaceTimeTransformer(Module):
         time_attend = get_attend_fn(causal = True, seq_len = time, k_seq_len = time, **attend_kwargs)
-        # rotary
-        rotary_pos_emb = self.time_rotary(time)
         # prepare cache
         time_attn_kv_caches = []
         has_kv_cache = exists(kv_cache)
         if has_kv_cache:
             past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
+            rotary_seq_len = 1
+            rotary_pos_offset = past_tokens.shape[-2]
+        else:
+            rotary_seq_len = time
+            rotary_pos_offset = 0
         kv_cache = default(kv_cache, (None,))
         iter_kv_cache = iter(kv_cache)
+        # rotary
+        rotary_pos_emb = self.time_rotary(rotary_seq_len, offset = rotary_pos_offset)
         # attention
         tokens = self.expand_streams(tokens)
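In plain terms: on a cached decode step only the newest token is carried forward, and its rotary embedding is fetched for a single timestep at the absolute position implied by the cache length, rather than recomputing the full-length table. A condensed, hypothetical sketch of that control flow (the `(batch, time, dim)` tensor layout and helper names are assumptions):

import torch

def prepare_decode_step(tokens, rotary, kv_cache = None):
    # `rotary` is a module like Rotary1DSketch above
    time = tokens.shape[1]
    if kv_cache is not None:
        # with a cache, only the newest timestep needs processing
        past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
        rotary_seq_len = 1
        rotary_pos_offset = past_tokens.shape[1]  # new token's absolute position
    else:
        rotary_seq_len = time
        rotary_pos_offset = 0
    # one timestep's rotation at the right position (the removed todo above)
    rotary_pos_emb = rotary(rotary_seq_len, offset = rotary_pos_offset)
    return tokens, rotary_pos_emb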

pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.50"
+version = "0.0.52"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }