will apply the golden gate rotary for this work as an option

This commit is contained in:
lucidrains 2025-10-01 10:07:54 -07:00
parent ceb1af263e
commit 882e63511b
2 changed files with 94 additions and 3 deletions


@@ -17,3 +17,12 @@ Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v
url = {https://arxiv.org/abs/2509.24527},
}
```
```bibtex
@misc{xiong2025ndrope,
author = {Jerry Xiong},
title = {On n-dimensional rotary positional embeddings},
year = {2025},
url = {https://jerryxio.ng/posts/nd-rope/}
}
```


@@ -6,10 +6,17 @@ from functools import partial
import torch
import torch.nn.functional as F
from torch.nn import Module, ModuleList, Parameter, Sequential, Linear, RMSNorm, Identity
from torch import cat, stack, arange, tensor, Tensor, is_tensor

# ein related
# b - batch
# n - sequence
# h - attention heads
# d - feature dimension
# f - frequencies (rotary)
# p - positions (3 for spacetime in this work)

import einx
from einops import einsum, rearrange, repeat, reduce
from einops.layers.torch import Rearrange
@@ -38,12 +45,87 @@ def exists(v):
def default(v, d):
    return v if exists(v) else d

def divisible_by(num, den):
    return (num % den) == 0

def l2norm(t):
    return F.normalize(t, dim = -1, p = 2)

def softclamp(t, value = 50.):
    return (t / value).tanh() * value

# golden gate rotary - Jerry Xiong, PhD student at UIUC
# https://jerryxio.ng/posts/nd-rope/

def _phi(m):
    x = 2.
    for _ in range(10):
        x = (1. + x) ** (1. / (m + 1.))
    return x
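
`_phi` iterates to the generalized golden ratio, the positive root of x ** (m + 1) = x + 1 that seeds the low-discrepancy sequence from Jerry Xiong's post. A minimal sanity sketch, not part of the commit (the tolerance is an assumption based on the 10 fixed-point iterations above):

```python
# hypothetical check: _phi(m) should approximately satisfy x ** (m + 1) == x + 1
x = _phi(3)
assert abs(x ** 4 - (x + 1.)) < 1e-6
```
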
def make_directions(n, d):
    g = _phi(d)
    alpha = (1.0 / g) ** arange(1, d + 1, dtype = torch.float64)
    i = arange(1, n + 1, dtype = torch.float64).unsqueeze(1)
    z = torch.fmod(i * alpha, 1.0)
    directions = torch.erfinv(2.0 * z - 1.0)
    directions = l2norm(directions)
    return directions.float()
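
`make_directions` maps that sequence through the Gaussian inverse CDF (`torch.erfinv`) and l2-normalizes, yielding `n` roughly uniformly spread unit vectors on the sphere in `d` dimensions, one direction per (head, frequency) pair below. A quick shape and norm check, with head and frequency counts chosen purely for illustration:

```python
# hypothetical sanity check: rows should be unit direction vectors in R^3
dirs = make_directions(8 * 32, 3) # e.g. 8 heads * 32 freqs, 3 spacetime dims
assert dirs.shape == (8 * 32, 3)
assert torch.allclose(dirs.norm(dim = -1), torch.ones(8 * 32))
```
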
class GoldenGateRoPENd(Module):
    def __init__(
        self,
        dim_pos,
        heads,
        dim_head,
        rope_min_freq = 1.,
        rope_max_freq = 10000.,
        rope_p_zero_freqs = 0., # proportion of frequencies set to 0
    ):
        super().__init__()
        assert divisible_by(dim_head, 2)

        n_freqs = dim_head // 2
        n_zero_freqs = round(rope_p_zero_freqs * n_freqs)

        omega = cat((
            torch.zeros(n_zero_freqs),
            rope_min_freq * (rope_max_freq / rope_min_freq) ** torch.linspace(0, 1, n_freqs - n_zero_freqs),
        ))

        directions = make_directions(heads * n_freqs, dim_pos)
        directions = rearrange(directions, '(h f) p -> h f p', h = heads)

        omega_expanded = rearrange(omega, 'f -> f 1')
        self.register_buffer('freqs', directions * omega_expanded) # shape: (h, f, p)

    def forward(
        self,
        x,  # (b h n d)
        pos # (b n p)
    ):
        dtype = x.dtype

        x, y = x.float().chunk(2, dim = -1) # (b, h, n, f)

        freqs = rearrange(self.freqs, 'h f p -> 1 h 1 f p')
        positions = rearrange(pos.float(), 'b n p -> b 1 n 1 p')

        # thetas for freqs and positions (batch, head, seq, freq)

        theta = reduce(freqs * positions, 'b h n f p -> b h n f', 'sum')

        # apply rotations

        cos_theta = torch.cos(theta)
        sin_theta = torch.sin(theta)

        x_out = x * cos_theta - y * sin_theta
        y_out = x * sin_theta + y * cos_theta

        out = cat((x_out, y_out), dim = -1)
        return out.type(dtype)
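
A hedged usage sketch of the module as defined above (batch size, sequence length, and head sizes are made-up values). Since `theta` is linear in position, the dot product between a rotated query and key depends only on their relative positions, the usual RoPE property carried over to n-dimensional coordinates:

```python
# illustrative only: shapes follow the (b h n d) / (b n p) comments in forward
rope = GoldenGateRoPENd(dim_pos = 3, heads = 8, dim_head = 64)

q = torch.randn(2, 8, 1024, 64) # (b h n d) queries (keys would be rotated the same way)
pos = torch.randn(2, 1024, 3)   # (b n p) continuous spacetime coordinates

rotated = rope(q, pos)
assert rotated.shape == q.shape
```
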
# multi-head rmsnorm

class MultiHeadRMSNorm(Module):
@@ -58,7 +140,7 @@ class MultiHeadRMSNorm(Module):
    def forward(
        self,
        x # (b h n d)
    ):
        normed = l2norm(x)
        scale = (self.gamma + 1.) * self.scale
@@ -117,7 +199,7 @@ class Attention(Module):
    def forward(
        self,
        tokens, # (b n d)
        kv_cache = None,
        return_kv_cache = False
    ):