will apply the golden gate rotary for this work as an option
parent ceb1af263e
commit 882e63511b
@@ -17,3 +17,12 @@ Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v
     url = {https://arxiv.org/abs/2509.24527},
 }
 ```
+
+```bibtex
+@misc{xiong2025ndrope,
+    author = {Jerry Xiong},
+    title = {On n-dimensional rotary positional embeddings},
+    year = {2025},
+    url = {https://jerryxio.ng/posts/nd-rope/}
+}
+```
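For context on the citation added above: in the referenced post, the angle for each rotary feature pair comes from projecting the full n-dimensional position onto a fixed unit direction and scaling it by a per-frequency magnitude, roughly θ = ω · ⟨u, pos⟩, instead of using a scalar 1-d position. A rough, self-contained sketch of that angle computation (all names and sizes below are illustrative, not taken from this repository):

```python
# illustrative sketch of n-dimensional rotary angles (not code from this commit)
import torch
from einops import einsum

heads, n_freqs, dim_pos = 2, 4, 3                # made-up sizes
u = torch.randn(heads, n_freqs, dim_pos)         # one direction per (head, frequency)
u = u / u.norm(dim = -1, keepdim = True)         # unit length
omega = torch.logspace(0, 4, n_freqs)            # per-frequency magnitudes
pos = torch.randn(1, 8, dim_pos)                 # (batch, seq, position dims)

# theta for each (batch, head, seq, freq) is omega * <u, pos>
theta = einsum(u * omega[:, None], pos, 'h f p, b n p -> b h n f')
print(theta.shape)                               # torch.Size([1, 2, 8, 4])
```

The intent, as described in the post, is that with the directions spread roughly uniformly over the sphere, each head/frequency pair becomes sensitive to a different mix of the position axes.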
@@ -6,10 +6,17 @@ from functools import partial
 import torch
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Parameter, Sequential, Linear, RMSNorm, Identity
-from torch import cat, stack, tensor, Tensor, is_tensor
+from torch import cat, stack, arange, tensor, Tensor, is_tensor
 
 # ein related
 
+# b - batch
+# n - sequence
+# h - attention heads
+# d - feature dimension
+# f - frequencies (rotary)
+# p - positions (3 for spacetime in this work)
+
 import einx
 from einops import einsum, rearrange, repeat, reduce
 from einops.layers.torch import Rearrange
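These single-letter axes are the legend for the einx / einops patterns used later in the file. As a quick reading aid (the tensor here is a throwaway example, not from the code):

```python
import torch
from einops import reduce

# b - batch, h - heads, n - sequence, f - frequencies, p - position dims
t = torch.randn(2, 8, 16, 32, 3)

# 'b h n f p -> b h n f' sums out the position axis, the same pattern the rotary theta uses
summed = reduce(t, 'b h n f p -> b h n f', 'sum')
print(summed.shape)  # torch.Size([2, 8, 16, 32])
```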
@@ -38,12 +45,87 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
+def divisible_by(num, den):
+    return (num % den) == 0
+
 def l2norm(t):
     return F.normalize(t, dim = -1, p = 2)
 
 def softclamp(t, value = 50.):
     return (t / value).tanh() * value
 
+# golden gate rotary - Jerry Xiong, PhD student at UIUC
+# https://jerryxio.ng/posts/nd-rope/
+
+def _phi(m):
+    x = 2.
+    for _ in range(10):
+        x = (1. + x) ** (1. / (m + 1.))
+    return x
+
+def make_directions(n, d):
+    g = _phi(d)
+    alpha = (1.0 / g) ** arange(1, d + 1, dtype = torch.float64)
+    i = arange(1, n + 1, dtype = torch.float64).unsqueeze(1)
+    z = torch.fmod(i * alpha, 1.0)
+    directions = torch.erfinv(2.0 * z - 1.0)
+    directions = l2norm(directions)
+    return directions.float()
+
+class GoldenGateRoPENd(Module):
+    def __init__(
+        self,
+        dim_pos,
+        heads,
+        dim_head,
+        rope_min_freq = 1.,
+        rope_max_freq = 10000.,
+        rope_p_zero_freqs = 0., # proportion of frequencies set to 0
+    ):
+        super().__init__()
+        assert divisible_by(dim_head, 2)
+
+        n_freqs = dim_head // 2
+        n_zero_freqs = round(rope_p_zero_freqs * n_freqs)
+
+        omega = cat((
+            torch.zeros(n_zero_freqs),
+            rope_min_freq * (rope_max_freq / rope_min_freq) ** torch.linspace(0, 1, n_freqs - n_zero_freqs),
+        ))
+
+        directions = make_directions(heads * n_freqs, dim_pos)
+        directions = rearrange(directions, '(h f) p -> h f p', h = heads)
+
+        omega_expanded = rearrange(omega, 'f -> f 1')
+        self.register_buffer('freqs', directions * omega_expanded) # shape: (h, f, p)
+
+    def forward(
+        self,
+        x,  # (b h n d)
+        pos # (b n p)
+    ):
+        dtype = x.dtype
+
+        x, y = x.float().chunk(2, dim = -1) # (b, h, n, f)
+
+        freqs = rearrange(self.freqs, 'h f p -> 1 h 1 f p')
+        positions = rearrange(pos.float(), 'b n p -> b 1 n 1 p')
+
+        # thetas for freqs and positions (batch, head, seq, freq)
+
+        theta = reduce(freqs * positions, 'b h n f p -> b h n f', 'sum')
+
+        # apply rotations
+
+        cos_theta = torch.cos(theta)
+        sin_theta = torch.sin(theta)
+
+        x_out = x * cos_theta - y * sin_theta
+        y_out = x * sin_theta + y * cos_theta
+
+        out = cat((x_out, y_out), dim = -1)
+        return out.type(dtype)
+
 # multi-head rmsnorm
 
 class MultiHeadRMSNorm(Module):
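A minimal sketch of exercising `GoldenGateRoPENd` on its own, assuming the class as added above is in scope; the sizes are made up, and the norm check just reflects that a rotary embedding is a pure rotation of each feature pair:

```python
import torch

rope = GoldenGateRoPENd(dim_pos = 3, heads = 8, dim_head = 64)

q = torch.randn(2, 8, 16, 64)  # (b h n d)
pos = torch.randn(2, 16, 3)    # (b n p) - e.g. (t, y, x) for spacetime

rotated = rope(q, pos)

assert rotated.shape == q.shape
# rotations preserve the norm of each feature pair, hence of the whole head dimension
assert torch.allclose(rotated.norm(dim = -1), q.norm(dim = -1), atol = 1e-4)
```

Since theta is linear in the positions, dot products between rotated queries and keys should depend only on relative spacetime offsets, which as I read it is the usual RoPE property the post carries over to n dimensions.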
@@ -58,7 +140,7 @@ class MultiHeadRMSNorm(Module):
 
     def forward(
         self,
-        x
+        x # (b h n d)
     ):
         normed = l2norm(x)
         scale = (self.gamma + 1.) * self.scale
@@ -117,7 +199,7 @@ class Attention(Module):
 
     def forward(
         self,
-        tokens,
+        tokens, # (b n d)
         kv_cache = None,
         return_kv_cache = False
     ):
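The hunks above only add shape comments to `Attention.forward`; the actual wiring of the rotary option is not shown in this view. A hedged sketch of how queries and keys might be rotated when a rope module is supplied (the helper name and the `rope` / `pos` arguments are assumptions, not part of this commit):

```python
# hypothetical helper - not part of the hunks shown above
def maybe_rotate(rope, q, k, pos):
    # q, k: (b h n d), pos: (b n p); no-op when the rotary option is not enabled
    if rope is None:
        return q, k
    return rope(q, pos), rope(k, pos)
```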