they employ two stability measures, qk rmsnorm and softclamping of attention logits

lucidrains 2025-10-01 09:40:24 -07:00
parent e8678364ba
commit 2e92c0121a
2 changed files with 62 additions and 2 deletions
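For context, a minimal standalone sketch of the softclamping half of this change, reusing the same tanh-based softclamp and its 50. default that appear in the diff below; the tensor shapes and random logits are illustrative only. The clamp smoothly bounds the attention logits to roughly (-value, value) before the softmax, so occasional outlier logits cannot blow up during training.

import torch

def softclamp(t, value = 50.):
    # smooth, differentiable clamp: near identity for |t| << value, saturating at +/- value
    return (t / value).tanh() * value

sim = torch.randn(1, 8, 16, 16) * 100.   # (batch, heads, query len, key len) logits with outliers
sim = softclamp(sim)                      # now bounded within (-50, 50)
attn = sim.softmax(dim = -1)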

View File

@@ -5,11 +5,12 @@ from functools import partial
 import torch
 import torch.nn.functional as F
-from torch.nn import Module, ModuleList, Sequential, Linear, RMSNorm, Identity
+from torch.nn import Module, ModuleList, Parameter, Sequential, Linear, RMSNorm, Identity
 from torch import cat, stack, tensor, Tensor, is_tensor
 # ein related
+import einx
 from einops import einsum, rearrange, repeat, reduce
 from einops.layers.torch import Rearrange
@@ -25,6 +26,32 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
+def l2norm(t):
+    return F.normalize(t, dim = -1, p = 2)
+def softclamp(t, value = 50.):
+    return (t / value).tanh() * value
+# multi-head rmsnorm
+class MultiHeadRMSNorm(Module):
+    def __init__(
+        self,
+        dim_head,
+        heads = 8
+    ):
+        super().__init__()
+        self.scale = dim_head ** 0.5
+        self.gamma = Parameter(torch.zeros(heads, dim_head)) # weight decay friendly
+    def forward(
+        self,
+        x
+    ):
+        normed = l2norm(x)
+        scale = (self.gamma + 1.) * self.scale
+        return einx.multiply('... h n d, h d', normed, scale)
 # attention
 class Attention(Module):
@@ -33,6 +60,7 @@ class Attention(Module):
         dim,
         dim_head = 64,
         heads = 8,
+        softclamp_value = 50.,
         pre_rmsnorm = True
     ):
         super().__init__()
@@ -47,6 +75,13 @@
         self.to_kv = LinearNoBias(dim, dim_inner * 2)
         self.to_out = LinearNoBias(dim_inner, dim)
+        # stability related
+        self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads)
+        self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads)
+        self.softclamp_value = softclamp_value
     def forward(
         self,
         tokens,
@@ -57,23 +92,47 @@
         q, k, v = (self.to_q(tokens), *self.to_kv(tokens).chunk(2, dim = -1))
         # split heads
         q, k, v = map(self.split_heads, (q, k, v))
+        # qk rmsnorm
+        q = self.q_heads_rmsnorm(q)
+        k = self.k_heads_rmsnorm(k)
         # caching
         if exists(kv_cache):
             ck, cv = kv_cache
             k = cat((ck, k), dim = -2)
             v = cat((cv, v), dim = -2)
-        q = q * self.scale
         # similarity
         sim = einsum(q, k, 'b h i d, b h j d -> b h i j')
+        # softclamping a la gemma 3
+        if exists(self.softclamp_value):
+            sim = softclamp(sim, self.softclamp_value)
+        # scale and attention
+        sim = sim * self.scale
         attn = sim.softmax(dim = -1)
         # aggregate
         out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
         # merge heads
         out = self.merge_heads(out)
         # combine heads
         out = self.to_out(out)
         if not return_kv_cache:
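For reference, a self-contained sketch of the qk rmsnorm half of the change, condensed from the MultiHeadRMSNorm module in the diff above. The zero-initialized gamma plus the (gamma + 1.) shift at use time means weight decay pulls the effective per-head gain toward 1 rather than 0, and the sqrt(dim_head) factor restores the magnitude that l2 normalization removes. The einx.multiply call is swapped here for a plain broadcasted multiply so the snippet only needs torch; shapes are illustrative.

import torch
import torch.nn.functional as F
from torch.nn import Module, Parameter

class MultiHeadRMSNorm(Module):
    def __init__(self, dim_head, heads = 8):
        super().__init__()
        self.scale = dim_head ** 0.5
        self.gamma = Parameter(torch.zeros(heads, dim_head))  # weight decay friendly

    def forward(self, x):
        normed = F.normalize(x, dim = -1, p = 2)       # l2 norm over each head's feature dim
        scale = (self.gamma + 1.) * self.scale         # learned gain, centered at sqrt(dim_head)
        return normed * scale.unsqueeze(1)             # broadcast (h, d) over (b, h, n, d)

# applied to queries and keys right after splitting heads, before the similarity logits

q = torch.randn(2, 8, 16, 64)    # (batch, heads, seq, dim_head)
k = torch.randn(2, 8, 16, 64)

q_norm = MultiHeadRMSNorm(64)
k_norm = MultiHeadRMSNorm(64)

q, k = q_norm(q), k_norm(k)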

View File

@@ -26,6 +26,7 @@ classifiers=[
 ]
 dependencies = [
+    "einx>=0.3.0",
     "einops>=0.8.1",
     "torch>=2.4"
 ]