fix rotary embeddings in presence of kv caching
This commit is contained in:
parent
7195bbb196
commit
11fd2c477c
@ -847,11 +847,12 @@ class Rotary1D(Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
seq_len
|
||||
seq_len,
|
||||
offset = 0
|
||||
):
|
||||
device, dtype = self.inv_freq.device, self.inv_freq.dtype
|
||||
|
||||
t = torch.arange(seq_len, device = device).type(dtype)
|
||||
t = torch.arange(seq_len, device = device).type(dtype) + offset
|
||||
freqs = einsum(t, self.inv_freq, 'i, j -> i j')
|
||||
|
||||
return cat((freqs, freqs), dim = -1)
|
||||
@ -867,7 +868,6 @@ def apply_rotations(
|
||||
rotations_seq_len = rotations.shape[-2]
|
||||
|
||||
# handle kv caching with rotations
|
||||
# todo - only fetch rotary embedding for one timestep
|
||||
|
||||
if rotations_seq_len > seq_len:
|
||||
rotations = rotations[-seq_len:]
|
||||
@ -1320,23 +1320,30 @@ class AxialSpaceTimeTransformer(Module):
|
||||
|
||||
time_attend = get_attend_fn(causal = True, seq_len = time, k_seq_len = time, **attend_kwargs)
|
||||
|
||||
# rotary
|
||||
|
||||
rotary_pos_emb = self.time_rotary(time)
|
||||
|
||||
# prepare cache
|
||||
|
||||
time_attn_kv_caches = []
|
||||
|
||||
has_kv_cache = exists(kv_cache)
|
||||
|
||||
|
||||
if has_kv_cache:
|
||||
past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
|
||||
|
||||
rotary_seq_len = 1
|
||||
rotary_pos_offset = past_tokens.shape[-2]
|
||||
else:
|
||||
rotary_seq_len = time
|
||||
rotary_pos_offset = 0
|
||||
|
||||
kv_cache = default(kv_cache, (None,))
|
||||
|
||||
iter_kv_cache = iter(kv_cache)
|
||||
|
||||
# rotary
|
||||
|
||||
rotary_pos_emb = self.time_rotary(rotary_seq_len, offset = rotary_pos_offset)
|
||||
|
||||
# attention
|
||||
|
||||
tokens = self.expand_streams(tokens)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dreamer4"
|
||||
version = "0.0.50"
|
||||
version = "0.0.52"
|
||||
description = "Dreamer 4"
|
||||
authors = [
|
||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user