fix rotary embeddings in presence of kv caching
This commit is contained in:
parent
7195bbb196
commit
11fd2c477c
@ -847,11 +847,12 @@ class Rotary1D(Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
seq_len
|
seq_len,
|
||||||
|
offset = 0
|
||||||
):
|
):
|
||||||
device, dtype = self.inv_freq.device, self.inv_freq.dtype
|
device, dtype = self.inv_freq.device, self.inv_freq.dtype
|
||||||
|
|
||||||
t = torch.arange(seq_len, device = device).type(dtype)
|
t = torch.arange(seq_len, device = device).type(dtype) + offset
|
||||||
freqs = einsum(t, self.inv_freq, 'i, j -> i j')
|
freqs = einsum(t, self.inv_freq, 'i, j -> i j')
|
||||||
|
|
||||||
return cat((freqs, freqs), dim = -1)
|
return cat((freqs, freqs), dim = -1)
|
||||||
@ -867,7 +868,6 @@ def apply_rotations(
|
|||||||
rotations_seq_len = rotations.shape[-2]
|
rotations_seq_len = rotations.shape[-2]
|
||||||
|
|
||||||
# handle kv caching with rotations
|
# handle kv caching with rotations
|
||||||
# todo - only fetch rotary embedding for one timestep
|
|
||||||
|
|
||||||
if rotations_seq_len > seq_len:
|
if rotations_seq_len > seq_len:
|
||||||
rotations = rotations[-seq_len:]
|
rotations = rotations[-seq_len:]
|
||||||
@ -1320,23 +1320,30 @@ class AxialSpaceTimeTransformer(Module):
|
|||||||
|
|
||||||
time_attend = get_attend_fn(causal = True, seq_len = time, k_seq_len = time, **attend_kwargs)
|
time_attend = get_attend_fn(causal = True, seq_len = time, k_seq_len = time, **attend_kwargs)
|
||||||
|
|
||||||
# rotary
|
|
||||||
|
|
||||||
rotary_pos_emb = self.time_rotary(time)
|
|
||||||
|
|
||||||
# prepare cache
|
# prepare cache
|
||||||
|
|
||||||
time_attn_kv_caches = []
|
time_attn_kv_caches = []
|
||||||
|
|
||||||
has_kv_cache = exists(kv_cache)
|
has_kv_cache = exists(kv_cache)
|
||||||
|
|
||||||
|
|
||||||
if has_kv_cache:
|
if has_kv_cache:
|
||||||
past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
|
past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
|
||||||
|
|
||||||
|
rotary_seq_len = 1
|
||||||
|
rotary_pos_offset = past_tokens.shape[-2]
|
||||||
|
else:
|
||||||
|
rotary_seq_len = time
|
||||||
|
rotary_pos_offset = 0
|
||||||
|
|
||||||
kv_cache = default(kv_cache, (None,))
|
kv_cache = default(kv_cache, (None,))
|
||||||
|
|
||||||
iter_kv_cache = iter(kv_cache)
|
iter_kv_cache = iter(kv_cache)
|
||||||
|
|
||||||
|
# rotary
|
||||||
|
|
||||||
|
rotary_pos_emb = self.time_rotary(rotary_seq_len, offset = rotary_pos_offset)
|
||||||
|
|
||||||
# attention
|
# attention
|
||||||
|
|
||||||
tokens = self.expand_streams(tokens)
|
tokens = self.expand_streams(tokens)
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "dreamer4"
|
name = "dreamer4"
|
||||||
version = "0.0.50"
|
version = "0.0.52"
|
||||||
description = "Dreamer 4"
|
description = "Dreamer 4"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user