diff --git a/README.md b/README.md
index a90ad62..4c89a2b 100644
--- a/README.md
+++ b/README.md
@@ -17,3 +17,12 @@ Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v
     url = {https://arxiv.org/abs/2509.24527},
 }
 ```
+
+```bibtex
+@misc{xiong2025ndrope,
+    author = {Jerry Xiong},
+    title = {On n-dimensional rotary positional embeddings},
+    year = {2025},
+    url = {https://jerryxio.ng/posts/nd-rope/}
+}
+```
diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index eae8877..9198091 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -6,10 +6,17 @@ from functools import partial
 import torch
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Parameter, Sequential, Linear, RMSNorm, Identity
-from torch import cat, stack, tensor, Tensor, is_tensor
+from torch import cat, stack, arange, tensor, Tensor, is_tensor
 
 # ein related
 
+# b - batch
+# n - sequence
+# h - attention heads
+# d - feature dimension
+# f - frequencies (rotary)
+# p - positions (3 for spacetime in this work)
+
 import einx
 from einops import einsum, rearrange, repeat, reduce
 from einops.layers.torch import Rearrange
@@ -38,12 +45,87 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
+def divisible_by(num, den):
+    return (num % den) == 0
+
 def l2norm(t):
     return F.normalize(t, dim = -1, p = 2)
 
 def softclamp(t, value = 50.):
     return (t / value).tanh() * value
 
+# golden gate rotary - Jerry Xiong, PhD student at UIUC
+# https://jerryxio.ng/posts/nd-rope/
+
+def _phi(m):
+    x = 2.
+    for _ in range(10):
+        x = (1. + x) ** (1. / (m + 1.)) # fixed point iteration for x ** (m + 1) = x + 1, the generalized golden ratio
+    return x
+
+def make_directions(n, d):
+    g = _phi(d)
+    alpha = (1.0 / g) ** arange(1, d + 1, dtype = torch.float64)
+    i = arange(1, n + 1, dtype = torch.float64).unsqueeze(1)
+    z = torch.fmod(i * alpha, 1.0) # low discrepancy points in the unit cube
+    directions = torch.erfinv(2.0 * z - 1.0) # gaussianize, so normalizing yields near-uniform directions on the sphere
+    directions = l2norm(directions)
+    return directions.float()
+
+class GoldenGateRoPENd(Module):
+    def __init__(
+        self,
+        dim_pos,
+        heads,
+        dim_head,
+        rope_min_freq = 1.,
+        rope_max_freq = 10000.,
+        rope_p_zero_freqs = 0., # proportion of frequencies set to 0
+    ):
+        super().__init__()
+        assert divisible_by(dim_head, 2)
+
+        n_freqs = dim_head // 2
+        n_zero_freqs = round(rope_p_zero_freqs * n_freqs)
+
+        omega = cat((
+            torch.zeros(n_zero_freqs),
+            rope_min_freq * (rope_max_freq / rope_min_freq) ** torch.linspace(0, 1, n_freqs - n_zero_freqs),
+        ))
+
+        directions = make_directions(heads * n_freqs, dim_pos)
+        directions = rearrange(directions, '(h f) p -> h f p', h = heads)
+
+        omega_expanded = rearrange(omega, 'f -> f 1')
+        self.register_buffer('freqs', directions * omega_expanded) # shape: (h, f, p)
+
+    def forward(
+        self,
+        x,  # (b h n d)
+        pos # (b n p)
+    ):
+        dtype = x.dtype
+
+        x, y = x.float().chunk(2, dim = -1) # (b h n f)
+
+        freqs = rearrange(self.freqs, 'h f p -> 1 h 1 f p')
+        positions = rearrange(pos.float(), 'b n p -> b 1 n 1 p')
+
+        # thetas for freqs and positions (batch, head, seq, freq)
+
+        theta = reduce(freqs * positions, 'b h n f p -> b h n f', 'sum')
+
+        # apply rotations
+
+        cos_theta = torch.cos(theta)
+        sin_theta = torch.sin(theta)
+
+        x_out = x * cos_theta - y * sin_theta
+        y_out = x * sin_theta + y * cos_theta
+
+        out = cat((x_out, y_out), dim = -1)
+        return out.to(dtype)
+
 # multi-head rmsnorm
 
 class MultiHeadRMSNorm(Module):
@@ -58,7 +140,7 @@
 
     def forward(
         self,
-        x
+        x # (b h n d)
    ):
         normed = l2norm(x)
         scale = (self.gamma + 1.) * self.scale
@@ -117,7 +199,7 @@ class Attention(Module):
 
     def forward(
        self,
-        tokens,
+        tokens, # (b n d)
         kv_cache = None,
         return_kv_cache = False
    ):
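
Not part of the patch — a minimal sanity-check sketch for reviewers. The import path `dreamer4.dreamer4` and all shapes/hyperparameters below are illustrative assumptions, and `rope_max_freq` is kept small so float32 trig stays accurate for the check. It exercises `GoldenGateRoPENd` and verifies the property rotary embeddings are built for: each (head, frequency) pair is rotated in its own 2D plane by an angle linear in position, so attention logits depend only on position differences, and adding a shared offset to every spacetime position is a no-op.

```python
import torch
from einops import einsum

# hypothetical import path, following the file touched by this diff
from dreamer4.dreamer4 import GoldenGateRoPENd

rope = GoldenGateRoPENd(
    dim_pos = 3,               # (t, y, x) spacetime coordinates
    heads = 8,
    dim_head = 64,
    rope_max_freq = 100.,      # lowered from the default for float32 numerical headroom
    rope_p_zero_freqs = 0.25   # leave a quarter of the frequency bands unrotated
)

q = torch.randn(2, 8, 16, 64) # (b h n d)
k = torch.randn(2, 8, 16, 64) # (b h n d)
pos = torch.randn(2, 16, 3)   # (b n p)

# rotation preserves shape

assert rope(q, pos).shape == q.shape

# attention logits should depend only on relative positions

logits = einsum(rope(q, pos), rope(k, pos), 'b h i d, b h j d -> b h i j')

offset = torch.randn(2, 1, 3) # same offset applied to every token
logits_shifted = einsum(rope(q, pos + offset), rope(k, pos + offset), 'b h i d, b h j d -> b h i j')

# loose tolerance - float32 trig at large phases is only approximately shift-invariant

assert torch.allclose(logits, logits_shifted, atol = 1e-2)
```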