they employ two stability measures, qk rmsnorm and softclamping of attention logits
This commit is contained in:
parent
e8678364ba
commit
2e92c0121a
@ -5,11 +5,12 @@ from functools import partial
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.nn import Module, ModuleList, Sequential, Linear, RMSNorm, Identity
|
from torch.nn import Module, ModuleList, Parameter, Sequential, Linear, RMSNorm, Identity
|
||||||
from torch import cat, stack, tensor, Tensor, is_tensor
|
from torch import cat, stack, tensor, Tensor, is_tensor
|
||||||
|
|
||||||
# ein related
|
# ein related
|
||||||
|
|
||||||
|
import einx
|
||||||
from einops import einsum, rearrange, repeat, reduce
|
from einops import einsum, rearrange, repeat, reduce
|
||||||
from einops.layers.torch import Rearrange
|
from einops.layers.torch import Rearrange
|
||||||
|
|
||||||
@ -25,6 +26,32 @@ def exists(v):
|
|||||||
def default(v, d):
|
def default(v, d):
|
||||||
return v if exists(v) else d
|
return v if exists(v) else d
|
||||||
|
|
||||||
|
def l2norm(t):
|
||||||
|
return F.normalize(t, dim = -1, p = 2)
|
||||||
|
|
||||||
|
def softclamp(t, value = 50.):
|
||||||
|
return (t / value).tanh() * value
|
||||||
|
|
||||||
|
# multi-head rmsnorm
|
||||||
|
|
||||||
|
class MultiHeadRMSNorm(Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim_head,
|
||||||
|
heads = 8
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.scale = dim_head ** 0.5
|
||||||
|
self.gamma = Parameter(torch.zeros(heads, dim_head)) # weight decay friendly
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x
|
||||||
|
):
|
||||||
|
normed = l2norm(x)
|
||||||
|
scale = (self.gamma + 1.) * self.scale
|
||||||
|
return einx.multiply('... h n d, h d', normed, scale)
|
||||||
|
|
||||||
# attention
|
# attention
|
||||||
|
|
||||||
class Attention(Module):
|
class Attention(Module):
|
||||||
@ -33,6 +60,7 @@ class Attention(Module):
|
|||||||
dim,
|
dim,
|
||||||
dim_head = 64,
|
dim_head = 64,
|
||||||
heads = 8,
|
heads = 8,
|
||||||
|
softclamp_value = 50.,
|
||||||
pre_rmsnorm = True
|
pre_rmsnorm = True
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -47,6 +75,13 @@ class Attention(Module):
|
|||||||
self.to_kv = LinearNoBias(dim, dim_inner * 2)
|
self.to_kv = LinearNoBias(dim, dim_inner * 2)
|
||||||
self.to_out = LinearNoBias(dim_inner, dim)
|
self.to_out = LinearNoBias(dim_inner, dim)
|
||||||
|
|
||||||
|
# stability related
|
||||||
|
|
||||||
|
self.q_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads)
|
||||||
|
self.k_heads_rmsnorm = MultiHeadRMSNorm(dim_head, heads = heads)
|
||||||
|
|
||||||
|
self.softclamp_value = softclamp_value
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
tokens,
|
tokens,
|
||||||
@ -57,23 +92,47 @@ class Attention(Module):
|
|||||||
|
|
||||||
q, k, v = (self.to_q(tokens), *self.to_kv(tokens).chunk(2, dim = -1))
|
q, k, v = (self.to_q(tokens), *self.to_kv(tokens).chunk(2, dim = -1))
|
||||||
|
|
||||||
|
# split heads
|
||||||
|
|
||||||
q, k, v = map(self.split_heads, (q, k, v))
|
q, k, v = map(self.split_heads, (q, k, v))
|
||||||
|
|
||||||
|
# qk rmsnorm
|
||||||
|
|
||||||
|
q = self.q_heads_rmsnorm(q)
|
||||||
|
k = self.k_heads_rmsnorm(k)
|
||||||
|
|
||||||
|
# caching
|
||||||
|
|
||||||
if exists(kv_cache):
|
if exists(kv_cache):
|
||||||
ck, cv = kv_cache
|
ck, cv = kv_cache
|
||||||
k = cat((ck, k), dim = -2)
|
k = cat((ck, k), dim = -2)
|
||||||
v = cat((cv, v), dim = -2)
|
v = cat((cv, v), dim = -2)
|
||||||
|
|
||||||
q = q * self.scale
|
# similarity
|
||||||
|
|
||||||
sim = einsum(q, k, 'b h i d, b h j d -> b h i j')
|
sim = einsum(q, k, 'b h i d, b h j d -> b h i j')
|
||||||
|
|
||||||
|
# softclamping a la gemma 3
|
||||||
|
|
||||||
|
if exists(self.softclamp_value):
|
||||||
|
sim = softclamp(sim, self.softclamp_value)
|
||||||
|
|
||||||
|
# scale and attention
|
||||||
|
|
||||||
|
sim = sim * self.scale
|
||||||
|
|
||||||
attn = sim.softmax(dim = -1)
|
attn = sim.softmax(dim = -1)
|
||||||
|
|
||||||
|
# aggregate
|
||||||
|
|
||||||
out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
|
out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
|
||||||
|
|
||||||
|
# merge heads
|
||||||
|
|
||||||
out = self.merge_heads(out)
|
out = self.merge_heads(out)
|
||||||
|
|
||||||
|
# combine heads
|
||||||
|
|
||||||
out = self.to_out(out)
|
out = self.to_out(out)
|
||||||
|
|
||||||
if not return_kv_cache:
|
if not return_kv_cache:
|
||||||
|
|||||||
@ -26,6 +26,7 @@ classifiers=[
|
|||||||
]
|
]
|
||||||
|
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"einx>=0.3.0",
|
||||||
"einops>=0.8.1",
|
"einops>=0.8.1",
|
||||||
"torch>=2.4"
|
"torch>=2.4"
|
||||||
]
|
]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user