will apply the golden gate rotary for this work as an option
This commit is contained in:
parent
ceb1af263e
commit
882e63511b
@ -17,3 +17,12 @@ Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v
|
||||
url = {https://arxiv.org/abs/2509.24527},
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{xiong2025ndrope,
|
||||
author = {Jerry Xiong},
|
||||
title = {On n-dimensional rotary positional embeddings},
|
||||
year = {2025},
|
||||
url = {https://jerryxio.ng/posts/nd-rope/}
|
||||
}
|
||||
```
|
||||
|
||||
@ -6,10 +6,17 @@ from functools import partial
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Module, ModuleList, Parameter, Sequential, Linear, RMSNorm, Identity
|
||||
from torch import cat, stack, tensor, Tensor, is_tensor
|
||||
from torch import cat, stack, arange, tensor, Tensor, is_tensor
|
||||
|
||||
# ein related
|
||||
|
||||
# b - batch
|
||||
# n - sequence
|
||||
# h - attention heads
|
||||
# d - feature dimension
|
||||
# f - frequencies (rotary)
|
||||
# p - positions (3 for spacetime in this work)
|
||||
|
||||
import einx
|
||||
from einops import einsum, rearrange, repeat, reduce
|
||||
from einops.layers.torch import Rearrange
|
||||
@ -38,12 +45,87 @@ def exists(v):
|
||||
def default(v, d):
|
||||
return v if exists(v) else d
|
||||
|
||||
def divisible_by(num, den):
|
||||
return (num % den) == 0
|
||||
|
||||
def l2norm(t):
|
||||
return F.normalize(t, dim = -1, p = 2)
|
||||
|
||||
def softclamp(t, value = 50.):
|
||||
return (t / value).tanh() * value
|
||||
|
||||
# golden gate rotary - Jerry Xiong, PhD student at UIUC
|
||||
# https://jerryxio.ng/posts/nd-rope/
|
||||
|
||||
def _phi(m):
|
||||
x = 2.
|
||||
for _ in range(10):
|
||||
x = (1. + x) ** (1. / (m + 1.))
|
||||
return x
|
||||
|
||||
def make_directions(n, d):
|
||||
g = _phi(d)
|
||||
alpha = (1.0 / g) ** arange(1, d + 1, dtype = torch.float64)
|
||||
i = arange(1, n + 1, dtype = torch.float64).unsqueeze(1)
|
||||
z = torch.fmod(i * alpha, 1.0)
|
||||
directions = torch.erfinv(2.0 * z - 1.0)
|
||||
directions = l2norm(directions)
|
||||
return directions.float()
|
||||
|
||||
class GoldenGateRoPENd(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim_pos,
|
||||
heads,
|
||||
dim_head,
|
||||
rope_min_freq = 1.,
|
||||
rope_max_freq = 10000.,
|
||||
rope_p_zero_freqs = 0., # proportion of frequencies set to 0
|
||||
):
|
||||
super().__init__()
|
||||
assert divisible_by(dim_head, 2)
|
||||
|
||||
n_freqs = dim_head // 2
|
||||
n_zero_freqs = round(rope_p_zero_freqs * n_freqs)
|
||||
|
||||
omega = cat((
|
||||
torch.zeros(n_zero_freqs),
|
||||
rope_min_freq * (rope_max_freq / rope_min_freq) ** torch.linspace(0, 1, n_freqs - n_zero_freqs),
|
||||
))
|
||||
|
||||
directions = make_directions(heads * n_freqs, dim_pos)
|
||||
directions = rearrange(directions, '(h f) p -> h f p', h = heads)
|
||||
|
||||
omega_expanded = rearrange(omega, 'f -> f 1')
|
||||
self.register_buffer('freqs', directions * omega_expanded) # shape: (h, f, p)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x, # (b h n d)
|
||||
pos # (b n p)
|
||||
):
|
||||
dtype = x
|
||||
|
||||
x, y = x.float().chunk(2, dim = -1) # (b, h, n, f)
|
||||
|
||||
freqs = rearrange(self.freqs, 'h f p -> 1 h 1 f p')
|
||||
positions = rearrange(pos.float(), 'b n p -> b 1 n 1 p')
|
||||
|
||||
# thetas for freqs and positions (batch, head, seq, freq)
|
||||
|
||||
theta = reduce(freqs * positions, 'b h n f p -> b h n f', 'sum')
|
||||
|
||||
# apply rotations
|
||||
|
||||
cos_theta = torch.cos(theta)
|
||||
sin_theta = torch.sin(theta)
|
||||
|
||||
x_out = x * cos_theta - y * sin_theta
|
||||
y_out = x * sin_theta + y * cos_theta
|
||||
|
||||
out = cat((x_out, y_out), dim = -1)
|
||||
return out.type_as(dtype)
|
||||
|
||||
# multi-head rmsnorm
|
||||
|
||||
class MultiHeadRMSNorm(Module):
|
||||
@ -58,7 +140,7 @@ class MultiHeadRMSNorm(Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x
|
||||
x # (b h n d)
|
||||
):
|
||||
normed = l2norm(x)
|
||||
scale = (self.gamma + 1.) * self.scale
|
||||
@ -117,7 +199,7 @@ class Attention(Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
tokens,
|
||||
tokens, # (b n d)
|
||||
kv_cache = None,
|
||||
return_kv_cache = False
|
||||
):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user