will apply the golden gate rotary for this work as an option

This commit is contained in:
lucidrains 2025-10-01 10:07:54 -07:00
parent ceb1af263e
commit 882e63511b
2 changed files with 94 additions and 3 deletions

View File

@ -17,3 +17,12 @@ Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v
url = {https://arxiv.org/abs/2509.24527},
}
```
```bibtex
@misc{xiong2025ndrope,
author = {Jerry Xiong},
title = {On n-dimensional rotary positional embeddings},
year = {2025},
url = {https://jerryxio.ng/posts/nd-rope/}
}
```

View File

@ -6,10 +6,17 @@ from functools import partial
import torch
import torch.nn.functional as F
from torch.nn import Module, ModuleList, Parameter, Sequential, Linear, RMSNorm, Identity
from torch import cat, stack, tensor, Tensor, is_tensor
from torch import cat, stack, arange, tensor, Tensor, is_tensor
# ein related
# b - batch
# n - sequence
# h - attention heads
# d - feature dimension
# f - frequencies (rotary)
# p - positions (3 for spacetime in this work)
import einx
from einops import einsum, rearrange, repeat, reduce
from einops.layers.torch import Rearrange
@ -38,12 +45,87 @@ def exists(v):
def default(v, d):
return v if exists(v) else d
def divisible_by(num, den):
return (num % den) == 0
def l2norm(t):
return F.normalize(t, dim = -1, p = 2)
def softclamp(t, value = 50.):
return (t / value).tanh() * value
# golden gate rotary - Jerry Xiong, PhD student at UIUC
# https://jerryxio.ng/posts/nd-rope/
def _phi(m):
x = 2.
for _ in range(10):
x = (1. + x) ** (1. / (m + 1.))
return x
def make_directions(n, d):
g = _phi(d)
alpha = (1.0 / g) ** arange(1, d + 1, dtype = torch.float64)
i = arange(1, n + 1, dtype = torch.float64).unsqueeze(1)
z = torch.fmod(i * alpha, 1.0)
directions = torch.erfinv(2.0 * z - 1.0)
directions = l2norm(directions)
return directions.float()
class GoldenGateRoPENd(Module):
def __init__(
self,
dim_pos,
heads,
dim_head,
rope_min_freq = 1.,
rope_max_freq = 10000.,
rope_p_zero_freqs = 0., # proportion of frequencies set to 0
):
super().__init__()
assert divisible_by(dim_head, 2)
n_freqs = dim_head // 2
n_zero_freqs = round(rope_p_zero_freqs * n_freqs)
omega = cat((
torch.zeros(n_zero_freqs),
rope_min_freq * (rope_max_freq / rope_min_freq) ** torch.linspace(0, 1, n_freqs - n_zero_freqs),
))
directions = make_directions(heads * n_freqs, dim_pos)
directions = rearrange(directions, '(h f) p -> h f p', h = heads)
omega_expanded = rearrange(omega, 'f -> f 1')
self.register_buffer('freqs', directions * omega_expanded) # shape: (h, f, p)
def forward(
self,
x, # (b h n d)
pos # (b n p)
):
dtype = x
x, y = x.float().chunk(2, dim = -1) # (b, h, n, f)
freqs = rearrange(self.freqs, 'h f p -> 1 h 1 f p')
positions = rearrange(pos.float(), 'b n p -> b 1 n 1 p')
# thetas for freqs and positions (batch, head, seq, freq)
theta = reduce(freqs * positions, 'b h n f p -> b h n f', 'sum')
# apply rotations
cos_theta = torch.cos(theta)
sin_theta = torch.sin(theta)
x_out = x * cos_theta - y * sin_theta
y_out = x * sin_theta + y * cos_theta
out = cat((x_out, y_out), dim = -1)
return out.type_as(dtype)
# multi-head rmsnorm
class MultiHeadRMSNorm(Module):
@ -58,7 +140,7 @@ class MultiHeadRMSNorm(Module):
def forward(
self,
x
x # (b h n d)
):
normed = l2norm(x)
scale = (self.gamma + 1.) * self.scale
@ -117,7 +199,7 @@ class Attention(Module):
def forward(
self,
tokens,
tokens, # (b n d)
kv_cache = None,
return_kv_cache = False
):